chenpangpang / transformers

Unverified commit 89a1f342, authored Jul 20, 2023 by Younes Belkada and committed by GitHub on Jul 20, 2023
[`RWKV`] Add Gradient Checkpointing support for RWKV (#24955)
add GC support for RWKV
parent 9f912ef6
Changes: 1 changed file with 27 additions and 3 deletions

src/transformers/models/rwkv/modeling_rwkv.py (+27, -3)
@@ -406,6 +406,7 @@ class RwkvPreTrainedModel(PreTrainedModel):
     base_model_prefix = "rwkv"
     _no_split_modules = ["RwkvBlock"]
     _keep_in_fp32_modules = ["time_decay", "time_first"]
+    supports_gradient_checkpointing = True
 
     def _init_weights(self, module):
         """Initialize the weights."""
@@ -605,6 +606,8 @@ class RwkvModel(RwkvPreTrainedModel):
         self.layers_are_rescaled = False
 
+        self.gradient_checkpointing = False
+
         # Initialize weights and apply final processing
         self.post_init()
@@ -659,14 +662,35 @@ class RwkvModel(RwkvPreTrainedModel):
             ]
             state[4] -= 1e30
 
+        if self.gradient_checkpointing and self.training:
+            if use_cache:
+                logger.warning_once(
+                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
+                )
+                use_cache = False
+
         hidden_states = inputs_embeds
 
         all_self_attentions = () if output_attentions else None
         all_hidden_states = () if output_hidden_states else None
         for idx, block in enumerate(self.blocks):
-            hidden_states, state, attentions = block(
-                hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        # None for past_key_value
+                        return module(*inputs, use_cache=use_cache, output_attentions=output_attentions)
+
+                    return custom_forward
+
+                hidden_states, state, attentions = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block), hidden_states, state
+                )
+            else:
+                hidden_states, state, attentions = block(
+                    hidden_states, state=state, use_cache=use_cache, output_attentions=output_attentions
+                )
 
             if (
                 self.layers_are_rescaled
                 and self.config.rescale_every > 0
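For reference, a minimal usage sketch (not part of the commit): with supports_gradient_checkpointing = True on RwkvPreTrainedModel, gradient checkpointing can be turned on for RWKV through the standard gradient_checkpointing_enable() API, trading recomputation in the backward pass for lower activation memory during training. The Hub checkpoint name "RWKV/rwkv-4-169m-pile" is assumed here purely for illustration.

from transformers import AutoTokenizer, RwkvForCausalLM

# Checkpoint name assumed for illustration; any RWKV checkpoint should work.
model_id = "RWKV/rwkv-4-169m-pile"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = RwkvForCausalLM.from_pretrained(model_id)

# Enable activation checkpointing via the standard API; before this commit the
# call would raise because RWKV did not declare supports_gradient_checkpointing.
model.gradient_checkpointing_enable()
model.train()

inputs = tokenizer("Gradient checkpointing trades compute for memory.", return_tensors="pt")
# During checkpointed training the model forces use_cache=False internally (see diff above).
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()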