Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
e2964b8a
Unverified
Commit
e2964b8a
authored
Sep 22, 2020
by
Stas Bekman
Committed by
GitHub
Sep 22, 2020
Browse files
[fsmt] no need to pass device (#7292)
parent
e4b94d8e
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
5 additions
and
4 deletions
+5
-4
src/transformers/modeling_fsmt.py
src/transformers/modeling_fsmt.py
+5
-4
No files found.
src/transformers/modeling_fsmt.py
View file @
e2964b8a
...
@@ -1150,13 +1150,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
...
@@ -1150,13 +1150,14 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
def
__init__
(
self
,
num_positions
,
embedding_dim
,
padding_idx
):
def
__init__
(
self
,
num_positions
,
embedding_dim
,
padding_idx
):
self
.
make_weight
(
num_positions
,
embedding_dim
,
padding_idx
)
self
.
make_weight
(
num_positions
,
embedding_dim
,
padding_idx
)
def
make_weight
(
self
,
num_positions
,
embedding_dim
,
padding_idx
,
device
=
None
):
def
make_weight
(
self
,
num_positions
,
embedding_dim
,
padding_idx
):
weight
=
self
.
get_embedding
(
num_positions
,
embedding_dim
,
padding_idx
)
weight
=
self
.
get_embedding
(
num_positions
,
embedding_dim
,
padding_idx
)
if
device
is
not
None
:
weight
=
weight
.
to
(
device
)
if
not
hasattr
(
self
,
"weight"
):
if
not
hasattr
(
self
,
"weight"
):
# in ___init__
super
().
__init__
(
num_positions
,
embedding_dim
,
padding_idx
,
_weight
=
weight
)
super
().
__init__
(
num_positions
,
embedding_dim
,
padding_idx
,
_weight
=
weight
)
else
:
else
:
# in forward
weight
=
weight
.
to
(
self
.
weight
.
device
)
self
.
weight
=
nn
.
Parameter
(
weight
)
self
.
weight
=
nn
.
Parameter
(
weight
)
self
.
weight
.
detach_
()
self
.
weight
.
detach_
()
self
.
weight
.
requires_grad
=
False
self
.
weight
.
requires_grad
=
False
...
@@ -1204,6 +1205,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
...
@@ -1204,6 +1205,6 @@ class SinusoidalPositionalEmbedding(nn.Embedding):
max_pos
=
self
.
padding_idx
+
1
+
seq_len
max_pos
=
self
.
padding_idx
+
1
+
seq_len
if
max_pos
>
self
.
weight
.
size
(
0
):
if
max_pos
>
self
.
weight
.
size
(
0
):
# expand embeddings if needed
# expand embeddings if needed
self
.
make_weight
(
max_pos
,
self
.
embedding_dim
,
self
.
padding_idx
,
device
=
input
.
device
)
self
.
make_weight
(
max_pos
,
self
.
embedding_dim
,
self
.
padding_idx
)
positions
=
self
.
make_positions
(
input
,
self
.
padding_idx
)
positions
=
self
.
make_positions
(
input
,
self
.
padding_idx
)
return
super
().
forward
(
positions
)
return
super
().
forward
(
positions
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment