Unverified commit 8f88cef1, authored by Matthew Douglas and committed by GitHub

Fix #1588 - torch compatibility for <=2.4 (#1590)

parent d2fe0e3c
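The fix replaces the concrete schema default `ScalarType dtype=float16`, which the schema parser in torch <= 2.4 rejects, with the optional form `ScalarType? dtype=None`, which parses on old and new versions alike; the float16 default then moves into each kernel as `dtype or torch.float16`. Below is a minimal self-contained sketch of that pattern, assuming only that `torch.library.define` and `torch.library.impl` are available; the `demo::to_dtype` op is hypothetical, not part of bitsandbytes:

```python
# Hypothetical minimal repro of the workaround in this commit, not
# bitsandbytes code: the schema declares the dtype as optional, and the
# intended float16 default is applied inside the Python kernel instead.
from typing import Optional

import torch

torch.library.define(
    "demo::to_dtype",
    "(Tensor A, ScalarType? dtype=None) -> Tensor",
)

@torch.library.impl("demo::to_dtype", "CPU")
def _(A: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    # Resolve the optional dtype to the intended float16 default at runtime.
    return A.to(dtype or torch.float16)

x = torch.randn(4)
assert torch.ops.demo.to_dtype(x).dtype == torch.float16
assert torch.ops.demo.to_dtype(x, dtype=torch.float32).dtype == torch.float32
```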
@@ -19,7 +19,7 @@ else:
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
     "bitsandbytes::int8_scaled_mm",
-    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
+    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType? dtype=None) -> Tensor",
 )
@@ -30,10 +30,10 @@ def _(
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     shapeC = (*A.shape[:-1], B.shape[0])
-    return torch.empty(shapeC, device=A.device, dtype=dtype)
+    return torch.empty(shapeC, device=A.device, dtype=dtype or torch.float16)


 torch.library.define(
@@ -98,7 +98,7 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
 # Default PyTorch-native implementation
-@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
+@register_kernel("bitsandbytes::int8_vectorwise_dequant", "default")
 def _(A: torch.Tensor, stats: torch.Tensor):
     # To dequantize we divide by 127, or multiply by the reciprocal.
     return A * stats.view(-1, 1) * 7.874015718698502e-3
@@ -106,7 +106,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType? dtype=None, Tensor? bias=None) -> Tensor",
 )
@@ -115,11 +115,11 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: "A must be int32")
-    return torch.empty_like(A, dtype=dtype)
+    return torch.empty_like(A, dtype=dtype or torch.float16)


 torch.library.define(
......
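For context, the `torch.empty(...)` bodies in the hunks above look like fake (meta) implementations: they only describe the output's shape, device, and dtype for tracing, which is why the `dtype or torch.float16` fallback must be mirrored there as well. A sketch of that registration pattern, assuming torch >= 2.4 where `torch.library.register_fake` is available; the `demo::scale` op is hypothetical, not part of bitsandbytes:

```python
# Hypothetical example, not bitsandbytes code: a fake impl is what
# torch.compile / FakeTensor tracing consult for output metadata, so it
# has to apply the same float16 fallback as the real kernels.
from typing import Optional

import torch

torch.library.define("demo::scale", "(Tensor A, ScalarType? dtype=None) -> Tensor")

@torch.library.register_fake("demo::scale")
def _(A: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
    # Only metadata matters here; no real computation happens.
    return torch.empty_like(A, dtype=dtype or torch.float16)
```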
@@ -28,7 +28,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -43,7 +43,7 @@ def _(
     if bias is not None:
         out += bias
-    return out.to(dtype)
+    return out.to(dtype or torch.float16)


 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
......
@@ -90,7 +90,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -121,7 +121,7 @@ def _(
     if bias is not None and bias.dtype != torch.float16:
         out.add_(bias)
-    return out.to(dtype)
+    return out.to(dtype or torch.float16)


 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
......
@@ -5,26 +5,31 @@ import torch
 from ..._ops import register_kernel


-@register_kernel("bitsandbytes::int8_scaled_mm", None)
+@register_kernel("bitsandbytes::int8_scaled_mm", "default")
 def _(
     A: torch.Tensor,
     B: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
-    return out
+    return torch.ops.bitsandbytes.int8_mm_dequant.default(
+        out_i32,
+        row_stats,
+        col_stats,
+        dtype=dtype or torch.float16,
+        bias=bias,
+    )


-@register_kernel("bitsandbytes::int8_linear_matmul", None)
+@register_kernel("bitsandbytes::int8_linear_matmul", "default")
 def _(A: torch.Tensor, B: torch.Tensor):
     return _int8_linear_matmul_impl(A, B)


-@register_kernel("bitsandbytes::int8_linear_matmul.out", None)
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "default")
 def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     torch._check(out.dtype == torch.int32)
     _int8_linear_matmul_impl(A, B, out)
......
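Call-site behavior after the change: omitting `dtype` now passes `None` through the schema, and each kernel falls back to float16, so existing callers see the same output dtype as before. A usage sketch, assuming a bitsandbytes build that includes this change and an available CUDA device; the shapes and stats values are illustrative only:

```python
# Assumes bitsandbytes is installed (importing it registers the ops) and
# a CUDA device is present; values here are illustrative, not meaningful.
import torch
import bitsandbytes  # noqa: F401  (side effect: registers bitsandbytes::* ops)

A = torch.randint(-128, 128, (4, 32), dtype=torch.int8, device="cuda")
B = torch.randint(-128, 128, (16, 32), dtype=torch.int8, device="cuda")
row_stats = torch.rand(4, device="cuda")   # per-row scale stats for A
col_stats = torch.rand(16, device="cuda")  # per-column scale stats for B

# dtype omitted: the kernel receives None and falls back to float16.
out = torch.ops.bitsandbytes.int8_scaled_mm(A, B, row_stats, col_stats)
assert out.shape == (4, 16) and out.dtype == torch.float16

# An explicit dtype still overrides the fallback.
out_bf16 = torch.ops.bitsandbytes.int8_scaled_mm(
    A, B, row_stats, col_stats, dtype=torch.bfloat16
)
assert out_bf16.dtype == torch.bfloat16
```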