Unverified Commit 27d916b3 authored by Egor's avatar Egor Committed by GitHub
Browse files

Moved int8_mm_dequant to default backend (#1626)

parent 8b858e4e
from collections.abc import Sequence
import ctypes as ct
from typing import Optional
import torch
......@@ -24,29 +23,6 @@ if torch.__version__ >= (2, 6):
).reshape(*A.shape[:-1], B.shape[0])
@register_kernel("bitsandbytes::int8_mm_dequant", "cpu")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Dequantize an int32 matmul accumulator using row and column statistics.

    Args:
        A: int32 tensor; all leading dimensions are flattened into rows.
        row_stats: float32 tensor, flattened to one scale per row of the
            flattened ``A`` (presumably the per-row absmax of the left
            operand -- confirm against the caller).
        col_stats: float32 tensor, flattened to one scale per column of ``A``
            (presumably the per-column absmax of the right operand).
        dtype: dtype of the result; ``torch.float16`` when omitted.
        bias: optional tensor added to the scaled result before the final cast.

    Returns:
        A 2-D tensor of shape ``(A.numel() // A.shape[-1], A.shape[-1])`` in
        ``dtype`` (the leading dimensions of ``A`` are NOT restored).

    Raises:
        Whatever ``torch._check`` raises when ``A`` is not int32 or either
        statistics tensor is not float32.
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    # reshape() rather than view(): view() raises on non-contiguous input,
    # while reshape() returns a view when it can and copies only when it must.
    A_calc = A.reshape(-1, A.shape[-1])
    row_scale = row_stats.reshape(-1).unsqueeze(-1)  # (rows, 1)
    col_scale = col_stats.reshape(-1).unsqueeze(0)  # (1, cols)

    # Each int8 operand was scaled by 127 / absmax, so the product must be
    # divided by 127 * 127. Spelled out exactly instead of the former literal
    # 6.200124e-05, which was off from 1 / 16129 = 6.2000124e-05.
    dequant_scale = 1.0 / (127.0 * 127.0)

    out = A_calc * (row_scale * col_scale) * dequant_scale
    if bias is not None:
        # `out` is a fresh float32 tensor, so the in-place add is safe.
        out += bias

    return out.to(dtype or torch.float16)
@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
......
......@@ -6,6 +6,29 @@ import torch
from ..._ops import register_kernel
@register_kernel("bitsandbytes::int8_mm_dequant", "default")
def _(
    A: torch.Tensor,
    row_stats: torch.Tensor,
    col_stats: torch.Tensor,
    dtype: Optional[torch.dtype] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Dequantize an int32 matmul accumulator using row and column statistics.

    Args:
        A: int32 tensor; all leading dimensions are flattened into rows.
        row_stats: float32 tensor, flattened to one scale per row of the
            flattened ``A`` (presumably the per-row absmax of the left
            operand -- confirm against the caller).
        col_stats: float32 tensor, flattened to one scale per column of ``A``
            (presumably the per-column absmax of the right operand).
        dtype: dtype of the result; ``torch.float16`` when omitted.
        bias: optional tensor added to the scaled result before the final cast.

    Returns:
        A 2-D tensor of shape ``(A.numel() // A.shape[-1], A.shape[-1])`` in
        ``dtype`` (the leading dimensions of ``A`` are NOT restored).

    Raises:
        Whatever ``torch._check`` raises when ``A`` is not int32 or either
        statistics tensor is not float32.
    """
    torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}")
    torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}")
    torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}")

    # reshape() rather than view(): view() raises on non-contiguous input,
    # while reshape() returns a view when it can and copies only when it must.
    A_calc = A.reshape(-1, A.shape[-1])
    row_scale = row_stats.reshape(-1).unsqueeze(-1)  # (rows, 1)
    col_scale = col_stats.reshape(-1).unsqueeze(0)  # (1, cols)

    # Each int8 operand was scaled by 127 / absmax, so the product must be
    # divided by 127 * 127. Spelled out exactly instead of the former literal
    # 6.200124e-05, which was off from 1 / 16129 = 6.2000124e-05.
    dequant_scale = 1.0 / (127.0 * 127.0)

    out = A_calc * (row_scale * col_scale) * dequant_scale
    if bias is not None:
        # `out` is a fresh float32 tensor, so the in-place add is safe.
        out += bias

    return out.to(dtype or torch.float16)
@register_kernel("bitsandbytes::int8_mixed_scaled_mm", "default")
def _(
A: torch.Tensor,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment