Unverified Commit 21bd3be1 authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

[`RWKV`] Rwkv fix for 8bit inference (#23468)

* rwkv fix for 8bit inference

* add comment
parent 1c460a52
......@@ -708,6 +708,11 @@ class RwkvModel(RwkvPreTrainedModel):
if self.training:
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
else:
# Deal with quantization statistics
if hasattr(block.attention.output.weight, "SCB"):
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
else:
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment