chenpangpang / transformers · Commits

Commit d64db6df
authored Nov 13, 2018 by lukovnikov

clean up pr

parent 7ba83730
Showing 1 changed file with 6 additions and 6 deletions.

modeling.py  +6 -6
@@ -25,10 +25,7 @@ import six
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
 
-ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
-
+from six import string_types
 
 def gelu(x):
     """Implementation of the gelu activation function.
@@ -42,6 +39,9 @@ def swish(x):
     return x * torch.sigmoid(x)
 
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
 
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -68,7 +68,7 @@ class BertConfig(object):
         intermediate_size: The size of the "intermediate" (i.e., feed-forward)
             layer in the Transformer encoder.
         hidden_act: The non-linear activation function (function or string) in the
-            encoder and pooler. If string, "gelu", "relu" and "swish" supported.
+            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
         hidden_dropout_prob: The dropout probabilitiy for all fully connected
             layers in the embeddings, encoder, and pooler.
         attention_probs_dropout_prob: The dropout ratio for the attention
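The hunk above only rewords the docstring, but it states the contract the last hunk implements: hidden_act may be either a registered name or a function. A hedged usage sketch; the module path and the idea that vocab_size is the only required constructor argument are assumptions from this era of the file, not shown in the diff.

    import torch
    from modeling import BertConfig

    # Pick the activation by name: BERTIntermediate resolves it via ACT2FN.
    config = BertConfig(vocab_size=32000, hidden_act="swish")

    # Or pass any callable directly; it is used as-is, without the ACT2FN lookup.
    config = BertConfig(vocab_size=32000, hidden_act=lambda x: x * torch.sigmoid(x))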
@@ -246,7 +246,7 @@ class BERTIntermediate(nn.Module):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         self.intermediate_act_fn = ACT2FN[config.hidden_act] \
-            if isinstance(config.hidden_act, str) else config.hidden_act
+            if isinstance(config.hidden_act, string_types) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
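The switch from str to six's string_types makes the check Python 2 safe: a config parsed from JSON holds unicode strings there, which are not instances of str, so the old check would have skipped the ACT2FN lookup and tried to call the raw string. A minimal sketch, not part of the commit:

    from six import string_types

    # six.string_types is (basestring,) on Python 2 and (str,) on Python 3.
    print(isinstance(u"gelu", str))           # False on Python 2, True on Python 3
    print(isinstance(u"gelu", string_types))  # True on both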