Unverified Commit 7ccac73f authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`RWKV`] Final fix RWMV 4bit (#26134)

* Final fix RWMV 4bit

* fixup

* add a test

* add more clarifications
parent 32ec7345
......@@ -31,6 +31,7 @@ from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
add_start_docstrings_to_model_forward,
is_bitsandbytes_available,
is_ninja_available,
is_torch_cuda_available,
logging,
......@@ -735,18 +736,35 @@ class RwkvModel(RwkvPreTrainedModel):
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
elif hasattr(block.attention.output.weight, "quant_state"):
block.attention.output.weight.quant_state[0].div_(
2 ** int(block_id // self.config.rescale_every)
)
block.feed_forward.value.weight.quant_state[0].div_(
2 ** int(block_id // self.config.rescale_every)
)
self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
else:
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
self.layers_are_rescaled = not self.training
def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
r"""
Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
be quantized again.
"""
if not is_bitsandbytes_available():
raise ImportError("Please install bitsandbytes to use this method.")
import bitsandbytes as bnb
dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
# re-quantize the model:
# we need to put it first on CPU then back to the device
# this will create an overhead :/
# We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
# bugs with bnb
quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
setattr(target_layer, "weight", quant_weight)
@add_start_docstrings(
"""
......
......@@ -172,6 +172,22 @@ class Bnb4BitTest(Base4bitTest):
# 4-bit parameters are packed in uint8 variables
self.assertTrue(module.weight.dtype == torch.uint8)
def test_rwkv_4bit(self):
r"""
A simple test to check if 4-bit RWKV inference works as expected.
"""
model_id = "RWKV/rwkv-4-169m-pile"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tok = AutoTokenizer.from_pretrained(model_id)
text = "Hello my name is"
input_ids = tok.encode(text, return_tensors="pt").to(0)
_ = model.generate(input_ids, max_new_tokens=30)
def test_generate_quality(self):
r"""
Test the generation quality of the quantized model and see that we are matching the expected output.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment