Unverified Commit 89a1f342 authored by Younes Belkada, committed by GitHub

[`RWKV`] Add Gradient Checkpointing support for RWKV (#24955)

add GC support for RWKV
parent 9f912ef6
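
Gradient checkpointing trades compute for memory: block activations are recomputed during the backward pass instead of being stored, which lowers peak memory when training RWKV. A minimal usage sketch of what this change enables (the checkpoint name is illustrative, not taken from the commit):

```python
from transformers import AutoTokenizer, RwkvModel

# Checkpoint name is an assumption for illustration; any RWKV checkpoint works the same way.
model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile")
tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")

# `supports_gradient_checkpointing = True` (added in this commit) is what allows this call;
# without it, `gradient_checkpointing_enable()` raises a ValueError.
model.gradient_checkpointing_enable()
model.train()

inputs = tokenizer("Gradient checkpointing trades compute for memory.", return_tensors="pt")
outputs = model(input_ids=inputs["input_ids"])
# Activations inside each RwkvBlock are recomputed during the backward pass instead of stored.
outputs.last_hidden_state.sum().backward()
```

With checkpointing enabled, `use_cache` is forced off during training, as the warning added in the diff below notes.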
@@ -406,6 +406,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rwkv"
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
+    supports_gradient_checkpointing = True

     def _init_weights(self, module):
         """Initialize the weights."""
@@ -605,6 +606,8 @@ class RwkvModel(RwkvPreTrainedModel):
         self.layers_are_rescaled = False

+        self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()
@@ -659,14 +662,35 @@ class RwkvModel(RwkvPreTrainedModel):
             ]
             state[4] -= 1e30

+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         hidden_states = inputs_embeds

         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for idx, block in enumerate(self.blocks):
-            hidden_states, state, attentions = block(
-                hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+                    return custom_forward
+
+                hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block), hidden_states, state
+                )
+            else:
+                hidden_states, state, attentions = block(
+                    hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
+                )
+
             if (
                 self.layers_are_rescaled
                 and self.config.rescale_every > 0
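
For reference, the `create_custom_forward` wrapper in the loop above exists because the reentrant `torch.utils.checkpoint.checkpoint` only forwards positional arguments to the function it re-runs on the backward pass, so keyword arguments such as `use_cache` and `output_attentions` are captured in a closure instead. A self-contained sketch of the same pattern, with `ToyBlock` as a hypothetical stand-in for `RwkvBlock` (not part of the commit):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for RwkvBlock: returns (hidden_states, state, attentions)."""

    def __init__(self, hidden_size=16):
        super().__init__()
        self.linear = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states, state=None, use_cache=False, output_attentions=False):
        return torch.tanh(self.linear(hidden_states)), state, None


block = ToyBlock()
hidden_states = torch.randn(2, 8, 16, requires_grad=True)
state = None  # as in training: use_cache is disabled, so no recurrent state is carried


def create_custom_forward(module):
    def custom_forward(*inputs):
        # Keyword arguments are closed over; only positional inputs go through checkpoint().
        return module(*inputs, use_cache=False, output_attentions=False)

    return custom_forward


# Inputs are saved, and the block's intermediate activations are recomputed during backward.
hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), hidden_states, state
)
hidden_states.sum().backward()
```

In the model itself this path is only taken when `self.gradient_checkpointing and self.training`, so inference keeps the plain block call.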