chenpangpang / transformers

Commit 1383c7b8, authored Jul 23, 2019 by thomwolf

Fix #869

Parent: 6070b554
Showing 1 changed file with 18 additions and 4 deletions.

pytorch_transformers/modeling_utils.py (+18 -4)
@@ -39,6 +39,20 @@ WEIGHTS_NAME = "pytorch_model.bin"
 TF_WEIGHTS_NAME = 'model.ckpt'
 
+try:
+    from torch.nn import Identity
+except ImportError:
+    # Older PyTorch compatibility
+    class Identity(nn.Module):
+        r"""A placeholder identity operator that is argument-insensitive.
+        """
+        def __init__(self, *args, **kwargs):
+            super(Identity, self).__init__()
+
+        def forward(self, input):
+            return input
+
 if not six.PY2:
     def add_start_docstrings(*docstr):
         def docstring_decorator(fn):
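For context, torch.nn.Identity only exists in PyTorch >= 1.1, so on older installs the import fails and the shim above is defined instead. A minimal sketch of how the fallback behaves (the tensor values and constructor arguments here are purely illustrative):

import torch
import torch.nn as nn

try:
    from torch.nn import Identity  # available in PyTorch >= 1.1
except ImportError:
    # Same fallback as in the commit: a no-op module that ignores its arguments
    class Identity(nn.Module):
        def __init__(self, *args, **kwargs):
            super(Identity, self).__init__()

        def forward(self, input):
            return input

# Identity passes its input through unchanged, whatever it was constructed with
layer = Identity(54, unused_kwarg=None)
x = torch.randn(2, 3)
assert torch.equal(layer(x), x)

Being argument-insensitive matters here: callers can swap Identity in for Linear or Dropout without touching the constructor call sites.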
@@ -731,7 +745,7 @@ class SequenceSummary(nn.Module):
             # We can probably just use the multi-head attention module of PyTorch >=1.1.0
             raise NotImplementedError
 
-        self.summary = nn.Identity()
+        self.summary = Identity()
         if hasattr(config, 'summary_use_proj') and config.summary_use_proj:
             if hasattr(config, 'summary_proj_to_labels') and config.summary_proj_to_labels and config.num_labels > 0:
                 num_classes = config.num_labels
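The projection width above depends on two config flags: project down to the label count when summary_proj_to_labels is set and there are labels, otherwise keep the hidden size. A minimal sketch of that sizing rule, with illustrative values in place of a real model config:

# Illustrative values; the real code reads these from the model config
hidden_size, num_labels = 768, 2
summary_proj_to_labels = True

# Project to the label count when requested, otherwise keep hidden_size
num_classes = num_labels if summary_proj_to_labels and num_labels > 0 else hidden_size
assert num_classes == 2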
@@ -739,15 +753,15 @@ class SequenceSummary(nn.Module):
                 num_classes = config.hidden_size
             self.summary = nn.Linear(config.hidden_size, num_classes)
 
-        self.activation = nn.Identity()
+        self.activation = Identity()
         if hasattr(config, 'summary_activation') and config.summary_activation == 'tanh':
             self.activation = nn.Tanh()
 
-        self.first_dropout = nn.Identity()
+        self.first_dropout = Identity()
         if hasattr(config, 'summary_first_dropout') and config.summary_first_dropout > 0:
             self.first_dropout = nn.Dropout(config.summary_first_dropout)
 
-        self.last_dropout = nn.Identity()
+        self.last_dropout = Identity()
         if hasattr(config, 'summary_last_dropout') and config.summary_last_dropout > 0:
             self.last_dropout = nn.Dropout(config.summary_last_dropout)
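This hunk follows a default-then-override pattern: each stage of the summary head starts as a no-op Identity and is replaced only when the config enables it, so the forward pass can compose the stages unconditionally. A minimal sketch of the resulting pipeline, using a hypothetical SimpleNamespace config in place of a real pytorch_transformers config class, and nn.Identity for brevity (PyTorch >= 1.1 assumed; older versions would use the shim added above):

import torch
import torch.nn as nn
from types import SimpleNamespace

# Hypothetical config; real model configs expose these same attributes
config = SimpleNamespace(summary_activation='tanh',
                         summary_first_dropout=0.1,
                         summary_last_dropout=0.1)

# Default every stage to a no-op, then override when the config asks for it
activation = nn.Identity()
if getattr(config, 'summary_activation', None) == 'tanh':
    activation = nn.Tanh()

first_dropout = nn.Identity()
if getattr(config, 'summary_first_dropout', 0) > 0:
    first_dropout = nn.Dropout(config.summary_first_dropout)

last_dropout = nn.Identity()
if getattr(config, 'summary_last_dropout', 0) > 0:
    last_dropout = nn.Dropout(config.summary_last_dropout)

# Stages compose without branching because every slot is always a Module
x = torch.randn(4, 768)
out = last_dropout(activation(first_dropout(x)))
print(out.shape)  # torch.Size([4, 768])

The bug fixed here (#869) was that these defaults were written as nn.Identity(), which does not exist on PyTorch < 1.1; routing them through the module-level Identity name makes the construction work on both old and new PyTorch.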