Unverified commit 89a1f342 authored by Younes Belkada, committed by GitHub

[`RWKV`] Add Gradient Checkpointing support for RWKV (#24955)

Add gradient checkpointing support for RWKV.
parent 9f912ef6
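For context: with `supports_gradient_checkpointing = True` set in the diff below, the standard `PreTrainedModel.gradient_checkpointing_enable()` call should stop raising for RWKV models. A minimal usage sketch from the training side (the `RWKV/rwkv-4-169m-pile` checkpoint name is only an illustrative example, not something this commit depends on):

```python
# Minimal sketch of enabling the new path when fine-tuning an RWKV model.
# The checkpoint name is illustrative; any RWKV checkpoint on the Hub should do.
from transformers import AutoTokenizer, RwkvForCausalLM

model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")

# Trades compute for memory: block activations are recomputed in the backward pass.
model.gradient_checkpointing_enable()
model.train()

inputs = tokenizer("Hello, world!", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()
```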
@@ -406,6 +406,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
    base_model_prefix = "rwkv"
    _no_split_modules = ["RwkvBlock"]
    _keep_in_fp32_modules = ["time_decay", "time_first"]
    supports_gradient_checkpointing = True

    def _init_weights(self, module):
        """Initialize the weights."""
@@ -605,6 +606,8 @@ class RwkvModel(RwkvPreTrainedModel):
        self.layers_are_rescaled = False
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()
@@ -659,14 +662,35 @@ class RwkvModel(RwkvPreTrainedModel):
            ]
            state[4] -= 1e30

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        hidden_states = inputs_embeds

        all_self_attentions = () if output_attentions else None
        all_hidden_states = () if output_hidden_states else None
        for idx, block in enumerate(self.blocks):
            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

                    return custom_forward

                hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(block), hidden_states, state
                )
            else:
                hidden_states, state, attentions = block(
                    hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
                )

            if (
                self.layers_are_rescaled
                and self.config.rescale_every > 0
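The `create_custom_forward` closure in the loop above exists because the re-entrant `torch.utils.checkpoint.checkpoint` call does not pass keyword arguments through to the wrapped function; flags like `use_cache` and `output_attentions` are captured in the closure so only tensors are handed to `checkpoint` positionally. A standalone sketch of the same pattern, with a made-up `ToyBlock` standing in for `RwkvBlock` (this is not the actual RWKV implementation):

```python
# Standalone sketch of the checkpoint-with-closure pattern from the diff.
# `ToyBlock` is a made-up stand-in for `RwkvBlock`, purely for illustration.
import torch
from torch.utils.checkpoint import checkpoint


class ToyBlock(torch.nn.Module):
    def __init__(self, hidden_size=16):
        super().__init__()
        self.linear = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, state, use_cache=False, output_attentions=False):
        hidden_states = torch.tanh(self.linear(hidden_states)) + state
        attentions = hidden_states if output_attentions else None
        return hidden_states, state, attentions


def create_custom_forward(module, use_cache, output_attentions):
    # Keyword flags are closed over; only tensors go through `checkpoint` positionally.
    def custom_forward(*inputs):
        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)

    return custom_forward


block = ToyBlock()
hidden_states = torch.randn(2, 4, 16, requires_grad=True)
state = torch.zeros(2, 4, 16)

# Activations inside `block` are recomputed during backward instead of stored.
# `use_reentrant=True` mirrors the historical default this pattern was written for.
hidden_states, state, _ = checkpoint(
    create_custom_forward(block, use_cache=False, output_attentions=False),
    hidden_states,
    state,
    use_reentrant=True,
)
hidden_states.sum().backward()
```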