chenpangpang / transformers · Commits

Commit 3d869726 (unverified), authored Jul 11, 2023 by Zehan Li, committed via GitHub on Jul 11, 2023
Parent: 2642d8d0

add gradient checkpointing for distilbert (#24719)

* add gradient checkpointing for distilbert
* reformatted
Showing 1 changed file with 28 additions and 3 deletions.

src/transformers/models/distilbert/modeling_distilbert.py (+28, -3):
@@ -324,6 +324,7 @@ class Transformer(nn.Module):
         super().__init__()
         self.n_layers = config.n_layers
         self.layer = nn.ModuleList([TransformerBlock(config) for _ in range(config.n_layers)])
+        self.gradient_checkpointing = False

     def forward(
         self,
@@ -356,9 +357,28 @@ class Transformer(nn.Module):
             if output_hidden_states:
                 all_hidden_states = all_hidden_states + (hidden_state,)

-            layer_outputs = layer_module(
-                x=hidden_state, attn_mask=attn_mask, head_mask=head_mask[i], output_attentions=output_attentions
-            )
+            if self.gradient_checkpointing and self.training:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs, output_attentions)
+
+                    return custom_forward
+
+                layer_outputs = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(layer_module),
+                    hidden_state,
+                    attn_mask,
+                    head_mask[i],
+                )
+            else:
+                layer_outputs = layer_module(
+                    hidden_state,
+                    attn_mask,
+                    head_mask[i],
+                    output_attentions,
+                )
+
             hidden_state = layer_outputs[-1]

             if output_attentions:
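For context, the torch.utils.checkpoint.checkpoint call introduced in the hunk above reruns each wrapped block's forward pass during backpropagation instead of keeping its intermediate activations alive. A minimal, self-contained sketch of that trade-off, using a toy block as a purely hypothetical stand-in for a TransformerBlock:

```python
import torch
from torch.utils.checkpoint import checkpoint

# Toy stand-in for a TransformerBlock; purely illustrative.
block = torch.nn.Sequential(torch.nn.Linear(16, 16), torch.nn.ReLU())
x = torch.randn(4, 16, requires_grad=True)

# Plain call: intermediate activations are stored for the backward pass.
y_plain = block(x)

# Checkpointed call: activations inside `block` are discarded after the forward
# pass and recomputed during backward, trading extra compute for lower memory.
y_ckpt = checkpoint(block, x)
y_ckpt.sum().backward()
```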
@@ -392,6 +412,7 @@ class DistilBertPreTrainedModel(PreTrainedModel):
     config_class = DistilBertConfig
     load_tf_weights = None
     base_model_prefix = "distilbert"
+    supports_gradient_checkpointing = True

     def _init_weights(self, module: nn.Module):
         """Initialize the weights."""
@@ -409,6 +430,10 @@ class DistilBertPreTrainedModel(PreTrainedModel):
             module.bias.data.zero_()
             module.weight.data.fill_(1.0)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, Transformer):
+            module.gradient_checkpointing = value
+

 DISTILBERT_START_DOCSTRING = r"""
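Because DistilBertPreTrainedModel now declares supports_gradient_checkpointing = True, the feature can be switched on through the usual PreTrainedModel helpers rather than by touching the Transformer module directly. A rough usage sketch, assuming a transformers version that exposes gradient_checkpointing_enable() (the checkpoint name below is only an example):

```python
import torch
from transformers import DistilBertModel

model = DistilBertModel.from_pretrained("distilbert-base-uncased")  # example checkpoint

# Turn on activation checkpointing for the Transformer stack patched in this
# commit; it only takes effect while the model is in training mode.
model.gradient_checkpointing_enable()
model.train()

input_ids = torch.tensor([[101, 7592, 102]])  # toy input ids, illustrative only
outputs = model(input_ids)
outputs.last_hidden_state.sum().backward()
```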