Unverified Commit 6d714a5c authored by Vladimir Malinovskii, committed by GitHub

Embedding4bit and Embedding8bit implementation (#1292)



* Embedding4bit and Embedding8bit implementation

* lint

* Update bitsandbytes/nn/modules.py
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update bitsandbytes/nn/modules.py
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* Update bitsandbytes/nn/modules.py
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>

* saving -> Saving

---------
Co-authored-by: Matthew Douglas <38992547+matthewdouglas@users.noreply.github.com>
parent 4be18838
@@ -4,6 +4,10 @@
# LICENSE file in the root directory of this source tree.
from .modules import (
    Embedding,
    Embedding4bit,
    Embedding8bit,
    EmbeddingFP4,
    EmbeddingNF4,
    Int8Params,
    Linear4bit,
    Linear8bitLt,
@@ -347,6 +347,23 @@ class Params4bit(torch.nn.Parameter):
        return new_param


def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
    if getattr(module.weight, "quant_state", None) is not None:
        return

    if getattr(module, "quant_state", None) is None:
        warnings.warn(
            "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
        )

    # the quant state got lost when the parameter got converted. This happens for example for fsdp
    # since we registered the module, we can recover the state here
    assert module.weight.shape[1] == 1

    if not isinstance(module.weight, Params4bit):
        module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)

    module.weight.quant_state = module.quant_state
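This helper centralizes the quant-state recovery that previously lived inline in `Linear4bit.forward` (see the next hunk) so it can be shared with `Embedding4bit`. As a rough illustration of the failure mode it repairs, here is a minimal sketch assuming a CUDA device; it mirrors the FSDP regression tests added at the bottom of this commit:

```python
import torch
import bitsandbytes as bnb

# Quantization happens when the layer is moved to the GPU.
module = bnb.nn.Linear4bit(64, 32).cuda()

# Wrappers such as FSDP may replace the parameter and drop its quant_state;
# the copy registered on the module itself survives.
module.weight.quant_state = None

# The forward pass calls fix_4bit_weight_quant_state_from_module, which
# restores module.weight.quant_state from module.quant_state.
module(torch.randn(1, 64, device="cuda"))
assert module.weight.quant_state is not None
```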
class Linear4bit(nn.Linear):
    """
    This class is the base module for the 4-bit quantization algorithm presented in [QLoRA](https://arxiv.org/abs/2305.14314).
@@ -449,22 +466,12 @@ class Linear4bit(nn.Linear):
                destination[prefix + "weight." + k] = v if keep_vars else v.detach()

    def forward(self, x: torch.Tensor):
        fix_4bit_weight_quant_state_from_module(self)

        # weights are cast automatically as Int8Params, but the bias has to be cast manually
        if self.bias is not None and self.bias.dtype != x.dtype:
            self.bias.data = self.bias.data.to(x.dtype)

        if getattr(self.weight, "quant_state", None) is None:
            if getattr(self, "quant_state", None) is not None:
                # the quant state got lost when the parameter got converted. This happens for example for fsdp
                # since we registered the module, we can recover the state here
                assert self.weight.shape[1] == 1
                if not isinstance(self.weight, Params4bit):
                    self.weight = Params4bit(self.weight, quant_storage=self.quant_storage, bnb_quantized=True)
                self.weight.quant_state = self.quant_state
            else:
                print(
                    "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
                )

        if not self.compute_type_is_set:
            self.set_compute_type(x)
            self.compute_type_is_set = True
@@ -658,6 +665,191 @@ def maybe_rearrange_weight(state_dict, prefix, local_metadata, strict, missing_k
state_dict[f"{prefix}weight"] = undo_layout(weight, tile_indices)
class Embedding8bit(nn.Embedding):
"""
This class implements [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm for embedding layer
Quantization API is similar to Linear8bitLt:
```python
import torch
import torch.nn as nn
from bitsandbytes.nn import Embedding8bit
fp16_module = nn.Embedding(128, 64)
int8_module = Embedding8bit(128, 64)
int8_module.load_state_dict(fp16_module.state_dict())
int8_module = int8_module.to(0) # Quantization happens here
```
"""
def __init__(self, num_embeddings, embedding_dim, device=None, dtype=None):
super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)
self.dtype = self.weight.data.dtype
self.weight = Int8Params(self.weight.data, has_fp16_weights=False, requires_grad=False)
def _save_to_state_dict(self, destination, prefix, keep_vars):
raise NotImplementedError("Saving Embedding8bit module is not implemented")
def forward(self, input: Tensor) -> Tensor:
if not hasattr(self.weight, "SCB"):
raise RuntimeError("Embedding layer is not quantized. Please call .cuda() or .to(device) first.")
rows = self.weight.data
row_stats = self.weight.SCB
assert rows.shape == (self.num_embeddings, self.embedding_dim)
assert row_stats.shape == (self.num_embeddings,)
compressed_output = F.embedding(input, rows)
compressed_output_stats = F.embedding(input, row_stats.view(self.num_embeddings, 1))
output = compressed_output * (compressed_output_stats / 127.0)
return output.to(self.dtype)
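For intuition, the forward above undoes the row-wise absmax quantization performed by Int8Params: each row is stored as int8 codes plus a single per-row scale (SCB), and the gathered rows are rescaled by scale / 127. A minimal sketch of that round trip in plain PyTorch, with illustrative sizes rather than the actual bitsandbytes kernels:

```python
import torch

weight = torch.randn(128, 64)                    # original fp embedding table
row_absmax = weight.abs().max(dim=1).values      # one scale per row (the SCB analogue)
weight_int8 = torch.round(weight / row_absmax[:, None] * 127).to(torch.int8)

ids = torch.tensor([3, 17, 42])
rows = weight_int8[ids].float()                  # embedding lookup on the int8 rows
scales = row_absmax[ids][:, None]                # lookup on the per-row scales
dequantized = rows * (scales / 127.0)            # same rescaling as Embedding8bit.forward

# Each element is off by at most half a quantization step.
assert torch.allclose(dequantized, weight[ids], atol=row_absmax.max() / 127)
```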
class Embedding4bit(nn.Embedding):
    """
    This is the base class, similar to Linear4bit. It implements the 4-bit quantization algorithm presented in
    [QLoRA](https://arxiv.org/abs/2305.14314) for embeddings.

    The quantization API is similar to Linear4bit:
    ```python
    import torch
    import torch.nn as nn

    from bitsandbytes.nn import Embedding4bit

    fp16_module = nn.Embedding(128, 64)
    quantized_module = Embedding4bit(128, 64)

    quantized_module.load_state_dict(fp16_module.state_dict())

    quantized_module = quantized_module.to(0)  # Quantization happens here
    ```
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        dtype=None,
        quant_type="fp4",
        quant_storage=torch.uint8,
        device=None,
    ):
        super().__init__(num_embeddings, embedding_dim, device=device, dtype=dtype)

        self.dtype = self.weight.data.dtype

        self.weight = Params4bit(
            self.weight.data,
            requires_grad=False,
            compress_statistics=None,
            quant_type=quant_type,
            quant_storage=quant_storage,
            module=self,
        )

        blocksize = self.weight.blocksize

        if embedding_dim % blocksize != 0:
            warnings.warn(
                f"Embedding size {embedding_dim} is not divisible by block size {blocksize}. "
                "This will lead to slow inference.",
            )

    def _forward_with_partial_dequantize(self, input: Tensor):
        assert self.embedding_dim % self.weight.quant_state.blocksize == 0

        w_4bit_uint8 = self.weight.data.view(torch.uint8).view(self.num_embeddings * self.embedding_dim // 2, 1)

        output_4bit = torch.nn.functional.embedding(
            weight=w_4bit_uint8.view(self.num_embeddings, self.embedding_dim // 2),
            input=input,
        ).view(-1, 1)
        assert output_4bit.shape == (input.numel() * self.embedding_dim // 2, 1)

        blocks_per_emb = self.embedding_dim // self.weight.blocksize

        absmax = self.weight.quant_state.absmax
        assert absmax.shape == (self.num_embeddings * blocks_per_emb,)

        output_absmax = torch.nn.functional.embedding(
            weight=absmax.view(self.num_embeddings, blocks_per_emb),
            input=input,
        ).view(
            -1,
        )
        assert output_absmax.shape == (input.numel() * blocks_per_emb,)

        output_quant_state = copy.deepcopy(self.weight.quant_state)
        output_quant_state.absmax = output_absmax
        output_quant_state.shape = torch.Size((*input.shape, self.embedding_dim))

        output = bnb.functional.dequantize_4bit(output_4bit, output_quant_state)
        assert output.shape == (*input.shape, self.embedding_dim)

        return output.to(self.dtype)
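The partial-dequantization path above avoids materializing the full fp table: it gathers the packed uint8 rows (two 4-bit codes per byte) and the matching per-block absmax rows, then dequantizes only the gathered slice. A short shape walk-through under illustrative sizes, assuming the default 4-bit blocksize of 64:

```python
# Illustrative sizes only; input.numel() == 10 stands in for a batch of token ids.
num_embeddings, embedding_dim, blocksize = 128, 64, 64
batch_tokens = 10

packed_bytes_per_row = embedding_dim // 2    # two 4-bit codes per uint8 byte -> 32
blocks_per_emb = embedding_dim // blocksize  # one absmax scale per block     -> 1

# After the two embedding lookups above:
#   output_4bit.shape   == (batch_tokens * packed_bytes_per_row, 1)  -> (320, 1)
#   output_absmax.shape == (batch_tokens * blocks_per_emb,)          -> (10,)
# dequantize_4bit then expands this slice to the final (*input.shape, embedding_dim) output.
```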
    def _save_to_state_dict(self, destination, prefix, keep_vars):
        raise NotImplementedError("Saving Embedding4bit module is not implemented")

    def forward(self, input: Tensor) -> Tensor:
        fix_4bit_weight_quant_state_from_module(self)

        if self.embedding_dim % self.weight.quant_state.blocksize == 0:
            return self._forward_with_partial_dequantize(input)

        dequantized_weight = bnb.functional.dequantize_4bit(self.weight.data, self.weight.quant_state)

        return torch.nn.functional.embedding(
            weight=dequantized_weight,
            input=input,
        ).to(self.dtype)


class EmbeddingFP4(Embedding4bit):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        dtype=None,
        quant_storage=torch.uint8,
        device=None,
    ):
        super().__init__(
            num_embeddings,
            embedding_dim,
            dtype=dtype,
            quant_type="fp4",
            quant_storage=quant_storage,
            device=device,
        )


class EmbeddingNF4(Embedding4bit):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        dtype=None,
        quant_storage=torch.uint8,
        device=None,
    ):
        super().__init__(
            num_embeddings,
            embedding_dim,
            dtype=dtype,
            quant_type="nf4",
            quant_storage=quant_storage,
            device=device,
        )
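The two subclasses only pin `quant_type`; construction and loading otherwise behave exactly like the Embedding4bit docstring example above. A minimal usage sketch, assuming a CUDA device (the `quant_storage=torch.float32` combination is one of those exercised by the tests below):

```python
import torch
import torch.nn as nn

from bitsandbytes.nn import EmbeddingNF4

fp16_module = nn.Embedding(128, 64)

nf4_module = EmbeddingNF4(128, 64, quant_storage=torch.float32)
nf4_module.load_state_dict(fp16_module.state_dict())
nf4_module = nf4_module.to(0)  # Quantization happens here

tokens = torch.randint(0, 128, (4, 16), device="cuda")
embeddings = nf4_module(tokens)  # (4, 16, 64), dequantized back to the source dtype
```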
class Linear8bitLt(nn.Linear):
    """
    This class is the base module for the [LLM.int8()](https://arxiv.org/abs/2208.07339) algorithm.
import inspect
import math
import einops
@@ -616,7 +617,97 @@ def test_fp8linear():
    assert bgraderr < 0.00002


def test_4bit_warnings(requires_cuda):


@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
    "embedding_class,quant_storage",
    [
        (bnb.nn.Embedding8bit, None),
        (bnb.nn.EmbeddingFP4, torch.uint8),
        (bnb.nn.EmbeddingFP4, torch.float32),
        (bnb.nn.EmbeddingNF4, torch.uint8),
        (bnb.nn.EmbeddingNF4, torch.float32),
    ],
    ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_lossless(embedding_class, input_shape, embedding_dim, quant_storage):
    num_embeddings = 128

    src_weight = (torch.randn((num_embeddings, embedding_dim), dtype=torch.float32) > 0).to(
        torch.float32
    ) * 2 - 1  # Embeddings filled with {-1, 1} values. It should compress losslessly

    emb_base = nn.Embedding(
        num_embeddings=num_embeddings,
        embedding_dim=embedding_dim,
        _freeze=True,
        _weight=src_weight,
    )

    if embedding_class is bnb.nn.Embedding8bit:
        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    else:
        e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage)

    e.load_state_dict(emb_base.state_dict())

    emb_base.cuda()
    e.cuda()

    input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")

    torch.testing.assert_close(
        actual=e(input_tokens),
        expected=emb_base(input_tokens),
    )
@pytest.mark.parametrize("embedding_dim", [64, 65])
@pytest.mark.parametrize("input_shape", [(10,), (10, 10), (10, 10, 10)], ids=str)
@pytest.mark.parametrize(
"embedding_class,quant_storage",
[
(bnb.nn.Embedding8bit, None),
(bnb.nn.EmbeddingFP4, torch.uint8),
(bnb.nn.EmbeddingFP4, torch.float32),
(bnb.nn.EmbeddingNF4, torch.uint8),
(bnb.nn.EmbeddingNF4, torch.float32),
],
ids=lambda x: x.__name__ if inspect.isclass(x) else str(x),
)
def test_embedding_error(embedding_class, input_shape, embedding_dim, quant_storage):
is_8bit = embedding_class is bnb.nn.Embedding8bit
num_embeddings = 128
src_weight = torch.rand((num_embeddings, embedding_dim), dtype=torch.float32)
emb_base = nn.Embedding(
num_embeddings=num_embeddings,
embedding_dim=embedding_dim,
_freeze=True,
_weight=src_weight,
)
if is_8bit:
e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
else:
e = embedding_class(num_embeddings=num_embeddings, embedding_dim=embedding_dim, quant_storage=quant_storage)
e.load_state_dict(emb_base.state_dict())
emb_base.cuda()
e.cuda()
input_tokens = torch.randint(low=0, high=num_embeddings, size=input_shape, device="cuda")
torch.testing.assert_close(
actual=e(input_tokens),
expected=emb_base(input_tokens),
atol=0.05 if is_8bit else 0.20,
rtol=0.0,
)
def test_4bit_linear_warnings():
    dim1 = 64

    with pytest.warns(UserWarning, match=r"inference or training"):
@@ -642,3 +733,58 @@ def test_4bit_warnings(requires_cuda):
        net(inp)

    assert len(record) == 2


def test_4bit_embedding_warnings():
    num_embeddings = 128
    default_block_size = 64

    with pytest.warns(UserWarning, match=r"inference."):
        net = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=default_block_size + 1)
        net.cuda()
        inp = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
        net(inp)


def test_4bit_embedding_weight_fsdp_fix():
    num_embeddings = 64
    embedding_dim = 32

    module = bnb.nn.Embedding4bit(num_embeddings=num_embeddings, embedding_dim=embedding_dim)
    module.cuda()
    module.weight.quant_state = None

    input_tokens = torch.randint(low=0, high=num_embeddings, size=(1,), device="cuda")
    module(input_tokens)

    assert module.weight.quant_state is not None


def test_4bit_linear_weight_fsdp_fix():
    inp_size = 64
    out_size = 32

    module = bnb.nn.Linear4bit(inp_size, out_size)
    module.cuda()
    module.weight.quant_state = None

    input_tensor = torch.randn((1, inp_size), device="cuda")
    module(input_tensor)

    assert module.weight.quant_state is not None


def test_embedding_not_implemented_error():
    with pytest.raises(NotImplementedError):
        emb = bnb.nn.Embedding4bit(32, 32)
        emb.state_dict()

    with pytest.raises(NotImplementedError):
        emb = bnb.nn.Embedding8bit(32, 32)
        emb.state_dict()