Unverified Commit 812ef06a authored by Ruheena Suhani Shaik, committed by GitHub

Add support for Intel Gaudi/HPU backend (#1662)



* Supports HPU backend in main branch

* Update bitsandbytes/backends/hpu/ops.py

Updates the assertion message

Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update bitsandbytes/backends/hpu/ops.py
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update ops.py

Fix lint issue

* Update ops.py

---------
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
parent e9fc96a2
@@ -26,7 +26,7 @@ supported_torch_devices = {
     "cpu",
     "cuda",  # NVIDIA/AMD GPU
     "xpu",  # Intel GPU
-    "hpu",  # Gaudi
+    "hpu",  # Intel Gaudi
     "npu",  # Ascend NPU
     "mps",  # Apple Silicon
 }
@@ -37,6 +37,9 @@ if torch.cuda.is_available():
 if hasattr(torch, "xpu") and torch.xpu.is_available():
     from .backends.xpu import ops as xpu_ops

+if hasattr(torch, "hpu") and torch.hpu.is_available():
+    from .backends.hpu import ops as hpu_ops
+

 def _import_backends():
     """
@@ -451,7 +451,7 @@ def matmul_4bit(
         else:
             return MatMul4Bit.apply(A, B, out, bias, quant_state)

-    if A.numel() == A.shape[-1] and A.requires_grad == False:
+    if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
         if A.shape[-1] % quant_state.blocksize != 0:
             warn(
                 f"Some matrices hidden dimension is not a multiple of {quant_state.blocksize} and efficient inference kernels are not supported for these (slow). Matrix input size found: {A.shape}",
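As a quick illustration of what this hunk changes, here is a small sketch (not part of the diff; the helper name is hypothetical) that mirrors the updated guard: single-token, inference-only activations take the fused gemv_4bit path, while HPU inputs now fall back to MatMul4Bit, which dequantizes the weight through the kernel added below.

import torch


def takes_gemv_fast_path(A: torch.Tensor) -> bool:
    # Mirrors the updated condition in matmul_4bit: inference-only, single-token input,
    # and not on an HPU device.
    is_single_token = A.numel() == A.shape[-1]  # e.g. shape (1, 1, hidden_dim)
    return is_single_token and A.requires_grad == False and A.device.type != "hpu"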
bitsandbytes/backends/hpu/ops.py (new file)

from collections.abc import Sequence
import math

import torch

from bitsandbytes.utils import _reverse_4bit_compress_format

from ..._ops import register_kernel
from ..utils import GAUDI_SW_VER


@register_kernel("bitsandbytes::dequantize_4bit", "hpu")
def _(
    A: torch.Tensor,
    absmax: torch.Tensor,
    blocksize: int,
    quant_type: str,
    shape: Sequence[int],
    dtype: torch.dtype,
) -> torch.Tensor:
    torch._check_is_size(blocksize)
    torch._check(quant_type == "nf4", lambda: f"quant_type must be nf4, got {quant_type}")
    torch._check(
        A.dtype in [torch.bfloat16, torch.uint8],
        lambda: f"quant_storage supports uint8 or bfloat16, but got {A.dtype}",
    )

    # Allow non-uint8 quant_storage by reinterpreting the packed data as uint8.
    if A.dtype != torch.uint8:
        A = A.view(torch.uint8)

    # Packed data stored as a (1, n) row keeps its orientation; any other layout is
    # transposed back after dequantization.
    transpose = False if len(A.shape) == 2 and A.shape[0] == 1 else True
    A = A.reshape(-1)

    # Gaudi SW releases before 1.22 use the reversed 4-bit packing order.
    if GAUDI_SW_VER and (GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22):
        A = _reverse_4bit_compress_format(A)

    # HPU dequantization function for NF4 quantized tensors.
    out_dq = torch.ops.hpu.dequantize_nf4(
        A,
        absmax.to(dtype),
        blocksize,
        out_shape=(math.prod(shape),),
        out_dtype=dtype,
    )
    output = out_dq.reshape(shape)
    if transpose:
        output = output.t()
    return output
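A hedged usage sketch (not part of this commit) of how this kernel is typically reached: through a 4-bit linear layer moved to the Gaudi device. It assumes a Gaudi device with the habana-torch-plugin runtime available, and that the Linear4bit weight is packed to NF4 when the module is moved to the device, mirroring the CUDA flow.

import torch
import bitsandbytes as bnb

# Illustrative sizes; the weight is assumed to quantize to NF4 on the .to("hpu") move.
layer = bnb.nn.Linear4bit(
    4096, 4096, bias=False, compute_dtype=torch.bfloat16, quant_type="nf4"
)
layer = layer.to("hpu")

x = torch.randn(1, 4096, dtype=torch.bfloat16, device="hpu")
y = layer(x)  # forward dequantizes the packed weight via the "hpu" kernel above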
bitsandbytes/backends/utils.py

+import subprocess
+
+from packaging import version
 import torch

 try:
@@ -59,3 +62,23 @@ _FP4_QUANT_TABLE = torch.tensor(
     else "cpu",  # Only cpu/xpu use this table for now.
 )
 CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE}
+
+
+def get_gaudi_sw_version():
+    """
+    Returns the installed version of Gaudi SW.
+    """
+    output = subprocess.run(
+        "pip list | grep habana-torch-plugin",
+        shell=True,
+        text=True,
+        capture_output=True,
+    )
+    # If grep returns nothing
+    if not output.stdout.strip():
+        return None
+    return version.parse(output.stdout.split("\n")[0].split()[-1])
+
+
+GAUDI_SW_VER = get_gaudi_sw_version()
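A small worked example (not from the diff) of how GAUDI_SW_VER gates the packing fallback in ops.py above; the release string is hypothetical, while real values come from the pip output parsed by get_gaudi_sw_version().

from packaging import version

# Hypothetical installed release of habana-torch-plugin.
GAUDI_SW_VER = version.parse("1.21.0")

# Same gate as in bitsandbytes/backends/hpu/ops.py: releases before 1.22 get the
# reversed 4-bit packing applied before calling the HPU dequantize kernel.
needs_reversal = GAUDI_SW_VER.major < 1 or GAUDI_SW_VER.minor < 22
print(needs_reversal)  # True for 1.21.x, False for 1.22 and later 1.x releases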
@@ -443,7 +443,7 @@ class Linear4bit(nn.Linear):
         )
         # self.persistent_buffers = [] # TODO consider as way to save quant state
         self.compute_dtype = compute_dtype
-        self.compute_type_is_set = False
+        self.compute_type_is_set = False if compute_dtype is None else True
         self.quant_state = None
         self.quant_storage = quant_storage
         self.ipex_linear_is_set = False
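A short sketch (not part of the diff) of what this one-line change means in practice, assuming forward() only calls set_compute_type() while the flag is unset, as in the existing module code: an explicitly passed compute_dtype is now treated as final rather than re-derived from the first input's dtype.

import torch
import bitsandbytes as bnb

explicit = bnb.nn.Linear4bit(128, 128, compute_dtype=torch.bfloat16)
inferred = bnb.nn.Linear4bit(128, 128)  # compute_dtype=None keeps the lazy inference

assert explicit.compute_type_is_set       # set eagerly now that a dtype was provided
assert not inferred.compute_type_is_set   # still resolved from the first forward() input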