Unverified Commit 7ccac73f authored by Younes Belkada, committed by GitHub

[`RWKV`] Final fix RWKV 4bit (#26134)

* Final fix RWKV 4bit

* fixup

* add a test

* add more clarifications
parent 32ec7345
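For reference, the code path fixed by this commit is exercised by loading an RWKV checkpoint in 4-bit and running inference, as in the test added below. A minimal sketch, assuming a CUDA device at index 0 and bitsandbytes installed (the model id and generation settings mirror the new test):

from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

model_id = "RWKV/rwkv-4-169m-pile"
quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)

# Loading with a 4-bit config replaces the linear layers with bitsandbytes Linear4bit modules;
# the first forward pass in eval mode then triggers the per-block rescaling path fixed here.
model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
tok = AutoTokenizer.from_pretrained(model_id)

input_ids = tok.encode("Hello my name is", return_tensors="pt").to(0)
output = model.generate(input_ids, max_new_tokens=30)
print(tok.decode(output[0], skip_special_tokens=True))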
...
@@ -31,6 +31,7 @@ from ...utils import (
     add_code_sample_docstrings,
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_bitsandbytes_available,
     is_ninja_available,
     is_torch_cuda_available,
     logging,
...
@@ -735,18 +736,35 @@ class RwkvModel(RwkvPreTrainedModel):
                             block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
                             block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
                         elif hasattr(block.attention.output.weight, "quant_state"):
-                            block.attention.output.weight.quant_state[0].div_(
-                                2 ** int(block_id // self.config.rescale_every)
-                            )
-                            block.feed_forward.value.weight.quant_state[0].div_(
-                                2 ** int(block_id // self.config.rescale_every)
-                            )
+                            self._bnb_4bit_dequantize_and_rescale(block.attention.output, block_id)
+                            self._bnb_4bit_dequantize_and_rescale(block.feed_forward.value, block_id)
                         else:
                             block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
                             block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
 
         self.layers_are_rescaled = not self.training
 
+    def _bnb_4bit_dequantize_and_rescale(self, target_layer, block_id):
+        r"""
+        Perform the dequantization and rescaling of the weights of a given layer. After that operation the layer will
+        be quantized again.
+        """
+        if not is_bitsandbytes_available():
+            raise ImportError("Please install bitsandbytes to use this method.")
+
+        import bitsandbytes as bnb
+
+        dequant_weights = bnb.functional.dequantize_4bit(target_layer.weight.data, target_layer.weight.quant_state)
+
+        dequant_weights.div_(2 ** int(block_id // self.config.rescale_every))
+
+        # re-quantize the model:
+        # we need to put it first on CPU then back to the device
+        # this will create an overhead :/
+        # We set requires_grad=False as we cannot compute gradients on top of 4bit parameters anyway and to avoid
+        # bugs with bnb
+        quant_weight = bnb.nn.Params4bit(dequant_weights.to("cpu"), requires_grad=False).to(dequant_weights.device)
+
+        setattr(target_layer, "weight", quant_weight)
+
 @add_start_docstrings(
     """
...
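For illustration, here is a standalone sketch of the dequantize, rescale, re-quantize round-trip that the new helper performs. This is not part of the commit; it assumes a CUDA device and an installed bitsandbytes, and the layer size and rescale factor are arbitrary:

import bitsandbytes as bnb

# A toy 4-bit linear layer; moving it to the GPU quantizes its weight and attaches a
# quant_state to it, mirroring what happens when a model is loaded with load_in_4bit.
layer = bnb.nn.Linear4bit(64, 64, bias=False).cuda()

# The packed 4-bit weight cannot be divided in place, so dequantize it first...
dequant = bnb.functional.dequantize_4bit(layer.weight.data, layer.weight.quant_state)
dequant.div_(2)  # arbitrary factor standing in for 2 ** (block_id // rescale_every)

# ...then re-quantize by rebuilding a Params4bit on CPU and moving it back to the device,
# which is where the overhead mentioned in the comments above comes from.
layer.weight = bnb.nn.Params4bit(dequant.to("cpu"), requires_grad=False).to("cuda")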
...
@@ -172,6 +172,22 @@ class Bnb4BitTest(Base4bitTest):
                     # 4-bit parameters are packed in uint8 variables
                     self.assertTrue(module.weight.dtype == torch.uint8)
 
+    def test_rwkv_4bit(self):
+        r"""
+        A simple test to check if 4-bit RWKV inference works as expected.
+        """
+        model_id = "RWKV/rwkv-4-169m-pile"
+
+        quantization_config = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True)
+        model = AutoModelForCausalLM.from_pretrained(model_id, quantization_config=quantization_config)
+
+        tok = AutoTokenizer.from_pretrained(model_id)
+
+        text = "Hello my name is"
+        input_ids = tok.encode(text, return_tensors="pt").to(0)
+
+        _ = model.generate(input_ids, max_new_tokens=30)
+
     def test_generate_quality(self):
         r"""
         Test the generation quality of the quantized model and see that we are matching the expected output.
...