Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
2f2aa0c8
Unverified
Commit
2f2aa0c8
authored
Aug 06, 2020
by
Teven
Committed by
GitHub
Aug 06, 2020
Browse files
added `n_inner` argument to gpt2 config (#6296)
parent
0a0d53dc
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
8 additions
and
2 deletions
+8
-2
src/transformers/configuration_gpt2.py
src/transformers/configuration_gpt2.py
+4
-0
src/transformers/modeling_gpt2.py
src/transformers/modeling_gpt2.py
+2
-1
src/transformers/modeling_tf_gpt2.py
src/transformers/modeling_tf_gpt2.py
+2
-1
No files found.
src/transformers/configuration_gpt2.py
View file @
2f2aa0c8
...
...
@@ -59,6 +59,8 @@ class GPT2Config(PretrainedConfig):
Number of hidden layers in the Transformer encoder.
n_head (:obj:`int`, optional, defaults to 12):
Number of attention heads for each attention layer in the Transformer encoder.
n_inner (:obj:`int`, optional, defaults to None):
Dimensionality of the inner feed-forward layers. :obj:`None` will set it to 4 times n_embd
activation_function (:obj:`str`, optional, defaults to 'gelu'):
Activation function selected in the list ["relu", "swish", "gelu", "tanh", "gelu_new"].
resid_pdrop (:obj:`float`, optional, defaults to 0.1):
...
...
@@ -122,6 +124,7 @@ class GPT2Config(PretrainedConfig):
n_embd
=
768
,
n_layer
=
12
,
n_head
=
12
,
n_inner
=
None
,
activation_function
=
"gelu_new"
,
resid_pdrop
=
0.1
,
embd_pdrop
=
0.1
,
...
...
@@ -145,6 +148,7 @@ class GPT2Config(PretrainedConfig):
self
.
n_embd
=
n_embd
self
.
n_layer
=
n_layer
self
.
n_head
=
n_head
self
.
n_inner
=
n_inner
self
.
activation_function
=
activation_function
self
.
resid_pdrop
=
resid_pdrop
self
.
embd_pdrop
=
embd_pdrop
...
...
src/transformers/modeling_gpt2.py
View file @
2f2aa0c8
...
...
@@ -240,10 +240,11 @@ class Block(nn.Module):
def
__init__
(
self
,
n_ctx
,
config
,
scale
=
False
):
super
().
__init__
()
nx
=
config
.
n_embd
inner_dim
=
config
.
n_inner
if
config
.
n_inner
is
not
None
else
4
*
nx
self
.
ln_1
=
nn
.
LayerNorm
(
nx
,
eps
=
config
.
layer_norm_epsilon
)
self
.
attn
=
Attention
(
nx
,
n_ctx
,
config
,
scale
)
self
.
ln_2
=
nn
.
LayerNorm
(
nx
,
eps
=
config
.
layer_norm_epsilon
)
self
.
mlp
=
MLP
(
4
*
nx
,
config
)
self
.
mlp
=
MLP
(
inner_dim
,
config
)
def
forward
(
self
,
x
,
layer_past
=
None
,
attention_mask
=
None
,
head_mask
=
None
,
use_cache
=
False
,
output_attentions
=
False
,
...
...
src/transformers/modeling_tf_gpt2.py
View file @
2f2aa0c8
...
...
@@ -194,10 +194,11 @@ class TFBlock(tf.keras.layers.Layer):
def
__init__
(
self
,
n_ctx
,
config
,
scale
=
False
,
**
kwargs
):
super
().
__init__
(
**
kwargs
)
nx
=
config
.
n_embd
inner_dim
=
config
.
n_inner
if
config
.
n_inner
is
not
None
else
4
*
nx
self
.
ln_1
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
config
.
layer_norm_epsilon
,
name
=
"ln_1"
)
self
.
attn
=
TFAttention
(
nx
,
n_ctx
,
config
,
scale
,
name
=
"attn"
)
self
.
ln_2
=
tf
.
keras
.
layers
.
LayerNormalization
(
epsilon
=
config
.
layer_norm_epsilon
,
name
=
"ln_2"
)
self
.
mlp
=
TFMLP
(
4
*
nx
,
config
,
name
=
"mlp"
)
self
.
mlp
=
TFMLP
(
inner_dim
,
config
,
name
=
"mlp"
)
def
call
(
self
,
x
,
layer_past
,
attention_mask
,
head_mask
,
use_cache
,
output_attentions
,
training
=
False
):
a
=
self
.
ln_1
(
x
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment