chenpangpang / transformers · Commit 421929b5 (unverified)

finish (#13593)

Authored Sep 16, 2021 by Patrick von Platen; committed by GitHub on Sep 16, 2021.
Parent: b5bab710
Showing 2 changed files with 21 additions and 30 deletions (+21 −30):

src/transformers/models/distilbert/modeling_distilbert.py   +19 −30
src/transformers/models/pegasus/modeling_pegasus.py          +2 −0
src/transformers/models/distilbert/modeling_distilbert.py
@@ -73,6 +73,17 @@ DISTILBERT_PRETRAINED_MODEL_ARCHIVE_LIST = [

 def create_sinusoidal_embeddings(n_pos, dim, out):
+    if is_deepspeed_zero3_enabled():
+        import deepspeed
+
+        with deepspeed.zero.GatheredParameters(out, modifier_rank=0):
+            if torch.distributed.get_rank() == 0:
+                _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
+    else:
+        _create_sinusoidal_embeddings(n_pos=n_pos, dim=dim, out=out)
+
+
+def _create_sinusoidal_embeddings(n_pos, dim, out):
     position_enc = np.array([[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)])
     out.requires_grad = False
     out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
     ...
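Why the wrapper/implementation split: under DeepSpeed ZeRO-3 each rank holds only a shard of every parameter, so writing the sinusoidal table in place requires gathering the full tensor first. A minimal sketch of the gather-then-modify pattern used above (the helper name init_on_rank0 and the init_fn argument are hypothetical; it assumes DeepSpeed is installed and torch.distributed is initialized):

import torch
import deepspeed


def init_on_rank0(param, init_fn):
    # Hypothetical helper mirroring the pattern in this hunk:
    # GatheredParameters temporarily reassembles the ZeRO-3-sharded tensor;
    # modifier_rank=0 broadcasts rank 0's in-place edits back to all shards
    # when the context exits.
    with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
        if torch.distributed.get_rank() == 0:
            init_fn(param)  # e.g. _create_sinusoidal_embeddings(..., out=param)

Gating the write on rank 0 keeps the ranks from racing to fill the same gathered buffer, and moving the whole pattern into create_sinusoidal_embeddings lets every call site stay a plain one-call expression, which is what the next two hunks clean up.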
@@ -86,19 +97,9 @@ class Embeddings(nn.Module):
         self.word_embeddings = nn.Embedding(config.vocab_size, config.dim, padding_idx=config.pad_token_id)
         self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.dim)
         if config.sinusoidal_pos_embds:
-            if is_deepspeed_zero3_enabled():
-                import deepspeed
-
-                with deepspeed.zero.GatheredParameters(self.position_embeddings.weight, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        create_sinusoidal_embeddings(
-                            n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-                        )
-            else:
-                create_sinusoidal_embeddings(
-                    n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
-                )
+            create_sinusoidal_embeddings(
+                n_pos=config.max_position_embeddings, dim=config.dim, out=self.position_embeddings.weight
+            )

         self.LayerNorm = nn.LayerNorm(config.dim, eps=1e-12)
         self.dropout = nn.Dropout(config.dropout)
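For reference, the table _create_sinusoidal_embeddings fills is the standard fixed sinusoidal positional encoding. A self-contained sketch (sizes are illustrative, and the cosine half for odd columns follows the standard construction, which the first hunk truncates):

import numpy as np
import torch

n_pos, dim = 512, 64  # illustrative sizes, not from the commit
position_enc = np.array(
    [[pos / np.power(10000, 2 * (j // 2) / dim) for j in range(dim)] for pos in range(n_pos)]
)
out = torch.empty(n_pos, dim)
out.requires_grad = False  # the table is fixed, never trained
out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))  # even columns: sine
out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))  # odd columns: cosine (standard construction)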
@@ -475,23 +476,9 @@ class DistilBertModel(DistilBertPreTrainedModel):
         self.embeddings.position_embeddings = nn.Embedding(self.config.max_position_embeddings, self.config.dim)

         if self.config.sinusoidal_pos_embds:
-            if is_deepspeed_zero3_enabled():
-                import deepspeed
-
-                with deepspeed.zero.GatheredParameters(self.embeddings.position_embeddings.weight, modifier_rank=0):
-                    if torch.distributed.get_rank() == 0:
-                        create_sinusoidal_embeddings(
-                            n_pos=self.config.max_position_embeddings,
-                            dim=self.config.dim,
-                            out=self.embeddings.position_embeddings.weight,
-                        )
-            else:
-                create_sinusoidal_embeddings(
-                    n_pos=self.config.max_position_embeddings,
-                    dim=self.config.dim,
-                    out=self.embeddings.position_embeddings.weight,
-                )
+            create_sinusoidal_embeddings(
+                n_pos=self.config.max_position_embeddings, dim=self.config.dim, out=self.position_embeddings.weight
+            )
         else:
             with torch.no_grad():
                 if num_position_embeds_diff > 0:
...

@@ -502,6 +489,8 @@ class DistilBertModel(DistilBertPreTrainedModel):
                     self.embeddings.position_embeddings.weight = nn.Parameter(
                         old_position_embeddings_weight[:num_position_embeds_diff]
                     )
+        # move position_embeddings to correct device
+        self.embeddings.position_embeddings.to(self.device)

     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
...
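The two lines added at the end of resize_position_embeddings fix a device mismatch: a freshly constructed nn.Embedding always lands on the CPU, even when the rest of the model lives on a GPU. A minimal sketch of the failure mode and the fix (illustrative sizes):

import torch
import torch.nn as nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

pos = nn.Embedding(16, 8).to(device)  # original table on the model's device
pos = nn.Embedding(32, 8)             # after a resize, the new module defaults to CPU
print(pos.weight.device)              # cpu, whatever device is
pos.to(device)                        # the move this commit adds
print(pos.weight.device)              # back on the model's device

Without the explicit move, the first forward pass after a resize on GPU would fail with a cross-device error as soon as the inputs hit the new table.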
src/transformers/models/pegasus/modeling_pegasus.py
@@ -668,6 +668,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.to(self.device)

     def get_position_embeddings(self) -> nn.Embedding:
         """
...

@@ -886,6 +887,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
             self.config.d_model,
             self.padding_idx,
         )
+        self.embed_positions.to(self.device)

     def get_position_embeddings(self) -> nn.Embedding:
         """
...
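Both Pegasus hunks apply the same one-line device fix after resize_position_embeddings rebuilds embed_positions, once in the encoder and once in the decoder. A hedged usage sketch through the public API these methods back (the checkpoint name is illustrative, and the snippet assumes the conditional-generation head forwards resize_position_embeddings and get_position_embeddings to the encoder and decoder shown above):

from transformers import PegasusForConditionalGeneration

model = PegasusForConditionalGeneration.from_pretrained("google/pegasus-xsum")
model.resize_position_embeddings(1024)  # rebuilds embed_positions, now moved onto model.device
print(model.get_position_embeddings())  # encoder and decoder sinusoidal position embeddings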