Unverified Commit 2adcb7a7 authored by Vivek Goel's avatar Vivek Goel Committed by GitHub
Browse files

Add function to reverse 4bit weights for HPU (#1757)

* Add function to reverse 4bit weights for HPU

* Fix lint error
parent b2a8a156
...@@ -3,12 +3,19 @@ import math ...@@ -3,12 +3,19 @@ import math
import torch import torch
from bitsandbytes.utils import _reverse_4bit_compress_format
from ..._ops import register_kernel from ..._ops import register_kernel
from ..utils import GAUDI_SW_VER from ..utils import GAUDI_SW_VER
# convert btw standard 4-bit compression format and ipex compression format
# needed for backward compatibility with older versions of gaudi sw
def _reverse_4bit_compress_format(weight: torch.Tensor):
out_1 = (weight & 0xF0) >> 4
out_2 = (weight & 0xF) << 4
out = out_1 | out_2
return out
@register_kernel("bitsandbytes::dequantize_4bit", "hpu") @register_kernel("bitsandbytes::dequantize_4bit", "hpu")
def _( def _(
A: torch.Tensor, A: torch.Tensor,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment