Unverified commit 8f88cef1, authored by Matthew Douglas, committed by GitHub

Fix #1588 - torch compatibility for <=2.4 (#1590)

parent d2fe0e3c
@@ -19,7 +19,7 @@ else:
 # Higher level op: int8 matmul + dequant + bias
 torch.library.define(
     "bitsandbytes::int8_scaled_mm",
-    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType dtype=float16) -> Tensor",
+    "(Tensor A, Tensor B, Tensor row_stats, Tensor col_stats, Tensor? bias=None, ScalarType? dtype=None) -> Tensor",
 )
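Note: this schema change is the core of the fix. Schema strings with a concrete ScalarType default such as dtype=float16 appear to be what breaks on torch <= 2.4; the fix declares the argument as an optional ScalarType? dtype=None and applies the float16 default on the Python side instead. A minimal sketch of the pattern, using a hypothetical mylib::scaled_cast op (not part of bitsandbytes):

import torch

# Declare dtype as an optional ScalarType with a None default; older
# torch releases can parse this where "ScalarType dtype=float16" fails.
torch.library.define(
    "mylib::scaled_cast",
    "(Tensor A, ScalarType? dtype=None) -> Tensor",
)

@torch.library.impl("mylib::scaled_cast", "CPU")
def _(A: torch.Tensor, dtype=None):
    # Resolve the default in Python instead of in the schema string.
    return A.to(dtype or torch.float16)

Calling torch.ops.mylib.scaled_cast(torch.randn(4)) then yields a float16 result, while an explicit dtype=torch.bfloat16 is respected.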
@@ -30,10 +30,10 @@ def _(
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     shapeC = (*A.shape[:-1], B.shape[0])
-    return torch.empty(shapeC, device=A.device, dtype=dtype)
+    return torch.empty(shapeC, device=A.device, dtype=dtype or torch.float16)


 torch.library.define(
@@ -98,7 +98,7 @@ def _(A: torch.Tensor, stats: torch.Tensor) -> torch.Tensor:
 # Default PyTorch-native implementation
-@register_kernel("bitsandbytes::int8_vectorwise_dequant", None)
+@register_kernel("bitsandbytes::int8_vectorwise_dequant", "default")
 def _(A: torch.Tensor, stats: torch.Tensor):
     # To dequantize we divide by 127, or multiply by the reciprocal.
     return A * stats.view(-1, 1) * 7.874015718698502e-3
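Note: alongside the dtype changes, implementations previously registered under the None device key now use an explicit "default" key. register_kernel is bitsandbytes' own helper (imported from ..._ops further below); the registry sketched here is hypothetical and only illustrates the fallback-key pattern, not the actual implementation:

from typing import Callable

# Hypothetical per-device kernel registry with a "default" fallback key.
# All names here (_KERNELS, register_kernel, dispatch) are illustrative.
_KERNELS: dict[tuple[str, str], Callable] = {}

def register_kernel(op_name: str, device: str):
    def decorator(fn: Callable) -> Callable:
        _KERNELS[(op_name, device)] = fn
        return fn
    return decorator

def dispatch(op_name: str, device: str, *args, **kwargs):
    # Prefer a device-specific kernel, else fall back to "default".
    fn = _KERNELS.get((op_name, device)) or _KERNELS[(op_name, "default")]
    return fn(*args, **kwargs)

A string key keeps the registry keys homogeneous and makes the fallback explicit rather than encoding it in None.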
@@ -106,7 +106,7 @@ def _(A: torch.Tensor, stats: torch.Tensor):
 torch.library.define(
     "bitsandbytes::int8_mm_dequant",
-    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType dtype=float16, Tensor? bias=None) -> Tensor",
+    "(Tensor A, Tensor row_stats, Tensor col_stats, ScalarType? dtype=None, Tensor? bias=None) -> Tensor",
 )
@@ -115,11 +115,11 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: "A must be int32")
-    return torch.empty_like(A, dtype=dtype)
+    return torch.empty_like(A, dtype=dtype or torch.float16)


 torch.library.define(
...
@@ -28,7 +28,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -43,7 +43,7 @@ def _(
     if bias is not None:
         out += bias
-    return out.to(dtype)
+    return out.to(dtype or torch.float16)


 @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
...
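Note: the dtype or torch.float16 fallback recurs in every backend. The idiom is safe because torch.dtype objects are always truthy and None is falsy, so the expression means "explicit dtype wins, otherwise float16". A standalone check (resolve_dtype is an illustrative name, not a bitsandbytes function):

from typing import Optional
import torch

def resolve_dtype(dtype: Optional[torch.dtype]) -> torch.dtype:
    # torch.dtype instances are truthy; only None triggers the fallback.
    return dtype or torch.float16

assert resolve_dtype(None) is torch.float16
assert resolve_dtype(torch.bfloat16) is torch.bfloat16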
@@ -90,7 +90,7 @@ def _(
     A: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
     bias: Optional[torch.Tensor] = None,
 ) -> torch.Tensor:
     torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
@@ -121,7 +121,7 @@ def _(
     if bias is not None and bias.dtype != torch.float16:
         out.add_(bias)
-    return out.to(dtype)
+    return out.to(dtype or torch.float16)


 @register_kernel("bitsandbytes::int8_vectorwise_quant", "cuda")
...
@@ -5,26 +5,31 @@ import torch
 from ..._ops import register_kernel


-@register_kernel("bitsandbytes::int8_scaled_mm", None)
+@register_kernel("bitsandbytes::int8_scaled_mm", "default")
 def _(
     A: torch.Tensor,
     B: torch.Tensor,
     row_stats: torch.Tensor,
     col_stats: torch.Tensor,
     bias: Optional[torch.Tensor] = None,
-    dtype=torch.float16,
+    dtype: Optional[torch.dtype] = None,
 ) -> torch.Tensor:
     out_i32 = torch.ops.bitsandbytes.int8_linear_matmul.default(A, B)
-    out = torch.ops.bitsandbytes.int8_mm_dequant.default(out_i32, row_stats, col_stats, dtype=dtype, bias=bias)
-    return out
+    return torch.ops.bitsandbytes.int8_mm_dequant.default(
+        out_i32,
+        row_stats,
+        col_stats,
+        dtype=dtype or torch.float16,
+        bias=bias,
+    )


-@register_kernel("bitsandbytes::int8_linear_matmul", None)
+@register_kernel("bitsandbytes::int8_linear_matmul", "default")
 def _(A: torch.Tensor, B: torch.Tensor):
     return _int8_linear_matmul_impl(A, B)


-@register_kernel("bitsandbytes::int8_linear_matmul.out", None)
+@register_kernel("bitsandbytes::int8_linear_matmul.out", "default")
 def _(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor):
     torch._check(out.dtype == torch.int32)
     _int8_linear_matmul_impl(A, B, out)
...
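Note: for the data flow as a whole, int8_scaled_mm composes an int8 matmul accumulated in int32 (int8_linear_matmul) with a row/column-scaled dequantization (int8_mm_dequant). Below is a rough pure-PyTorch sketch of the semantics as inferred from the hunks above; the per-operand 1/127 scaling mirrors the reciprocal constant in the vectorwise dequant kernel, but treat this as an approximation, not the library's kernel:

from typing import Optional
import torch

def int8_scaled_mm_reference(
    A: torch.Tensor,           # int8, shape (M, K); 2-D assumed for brevity
    B: torch.Tensor,           # int8, shape (N, K); result is A @ B.T
    row_stats: torch.Tensor,   # per-row absmax stats for A, shape (M,)
    col_stats: torch.Tensor,   # per-column absmax stats for B, shape (N,)
    bias: Optional[torch.Tensor] = None,
    dtype: Optional[torch.dtype] = None,
) -> torch.Tensor:
    # int8 x int8 matmul with int32 accumulation, as int8_linear_matmul does.
    out_i32 = A.to(torch.int32) @ B.to(torch.int32).t()
    # Dequantize: both operands were quantized to [-127, 127], hence two
    # factors of 1/127 (cf. the reciprocal-of-127 constant used above).
    out = out_i32.float() * row_stats.view(-1, 1) * col_stats.view(1, -1)
    out = out / (127.0 * 127.0)
    if bias is not None:
        out = out + bias
    return out.to(dtype or torch.float16)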