Unverified commit 21bd3be1, authored by Younes Belkada and committed by GitHub
Browse files

[`RWKV`] Rwkv fix for 8bit inference (#23468)

* rwkv fix for 8bit inference

* add comment
parent 1c460a52
...@@ -709,8 +709,13 @@ class RwkvModel(RwkvPreTrainedModel): ...@@ -709,8 +709,13 @@ class RwkvModel(RwkvPreTrainedModel):
block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every)) block.attention.output.weight.mul_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every)) block.feed_forward.value.weight.mul_(2 ** int(block_id // self.config.rescale_every))
else: else:
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every)) # Deal with quantization statistics
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every)) if hasattr(block.attention.output.weight, "SCB"):
block.attention.output.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.SCB.div_(2 ** int(block_id // self.config.rescale_every))
else:
block.attention.output.weight.div_(2 ** int(block_id // self.config.rescale_every))
block.feed_forward.value.weight.div_(2 ** int(block_id // self.config.rescale_every))
self.layers_are_rescaled = not self.training self.layers_are_rescaled = not self.training
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment