Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
5347d000
Unverified
Commit
5347d000
authored
Aug 17, 2023
by
Arthur
Committed by
GitHub
Aug 17, 2023
Browse files
[`SwitchTransformers`] Remove unused module (#25427)
* remove unused module * remove old feed_forward_proj * fixup
parent
d6bf08f7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
2 additions
and
46 deletions
+2
-46
src/transformers/models/switch_transformers/configuration_switch_transformers.py
.../switch_transformers/configuration_switch_transformers.py
+2
-17
src/transformers/models/switch_transformers/modeling_switch_transformers.py
...odels/switch_transformers/modeling_switch_transformers.py
+0
-29
No files found.
src/transformers/models/switch_transformers/configuration_switch_transformers.py
View file @
5347d000
...
...
@@ -122,7 +122,7 @@ class SwitchTransformersConfig(PretrainedConfig):
router_z_loss_coef
=
0.001
,
router_aux_loss_coef
=
0.001
,
initializer_factor
=
1.0
,
feed_forward_proj
=
"relu"
,
dense_act_fn
=
"relu"
,
is_encoder_decoder
=
True
,
add_router_probs
=
False
,
use_cache
=
True
,
...
...
@@ -171,27 +171,12 @@ class SwitchTransformersConfig(PretrainedConfig):
self
.
dropout_rate
=
dropout_rate
self
.
layer_norm_epsilon
=
layer_norm_epsilon
self
.
initializer_factor
=
initializer_factor
self
.
feed_forward_proj
=
feed_forward_proj
self
.
use_cache
=
use_cache
self
.
add_router_probs
=
add_router_probs
self
.
router_z_loss_coef
=
router_z_loss_coef
self
.
router_aux_loss_coef
=
router_aux_loss_coef
act_info
=
self
.
feed_forward_proj
.
split
(
"-"
)
self
.
dense_act_fn
=
act_info
[
-
1
]
self
.
is_gated_act
=
act_info
[
0
]
==
"gated"
if
len
(
act_info
)
>
1
and
act_info
[
0
]
!=
"gated"
or
len
(
act_info
)
>
2
:
raise
ValueError
(
f
"`feed_forward_proj`:
{
feed_forward_proj
}
is not a valid activation function of the dense layer."
"Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
"'gated-gelu' or 'relu'"
)
# for backwards compatibility
if
feed_forward_proj
==
"gated-gelu"
:
self
.
dense_act_fn
=
"gelu_new"
self
.
dense_act_fn
=
dense_act_fn
super
().
__init__
(
pad_token_id
=
pad_token_id
,
...
...
src/transformers/models/switch_transformers/modeling_switch_transformers.py
View file @
5347d000
...
...
@@ -282,25 +282,6 @@ class SwitchTransformersDenseActDense(nn.Module):
return
hidden_states
# Copied from transformers.models.longt5.modeling_longt5.LongT5DenseGatedActDense with LongT5->SwitchTransformers
class SwitchTransformersDenseGatedActDense(nn.Module):
    """Gated feed-forward block built from three bias-free linear layers.

    The input is projected twice from ``config.d_model`` to ``config.d_ff``
    (``wi_0`` and ``wi_1``). The ``wi_0`` branch goes through the activation
    selected by ``config.dense_act_fn`` and is multiplied element-wise with
    the ``wi_1`` branch. The product is passed through dropout and projected
    back to ``config.d_model`` by ``wo``.
    """

    def __init__(self, config: SwitchTransformersConfig):
        """Create the projections, dropout and activation from ``config``.

        Reads ``d_model``, ``d_ff``, ``dropout_rate`` and ``dense_act_fn``
        from ``config``.
        """
        super().__init__()
        model_dim, ff_dim = config.d_model, config.d_ff
        # Submodules are created in the original order (wi_0, wi_1, wo, dropout)
        # so parameter registration order stays the same.
        self.wi_0 = nn.Linear(model_dim, ff_dim, bias=False)
        self.wi_1 = nn.Linear(model_dim, ff_dim, bias=False)
        self.wo = nn.Linear(ff_dim, model_dim, bias=False)
        self.dropout = nn.Dropout(config.dropout_rate)
        # Activation is looked up by name in the ACT2FN table.
        self.act = ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        """Apply the gated feed-forward transform to ``hidden_states``.

        NOTE(review): presumably the last dimension of ``hidden_states`` is
        ``config.d_model``, as required by the input linear layers -- confirm
        against callers.
        """
        activated_branch = self.act(self.wi_0(hidden_states))
        linear_branch = self.wi_1(hidden_states)
        gated = activated_branch * linear_branch
        return self.wo(self.dropout(gated))
class
SwitchTransformersSparseMLP
(
nn
.
Module
):
r
"""
Implementation of the Switch Transformers Sparse MLP module.
...
...
@@ -861,16 +842,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
module
.
wo
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
((
self
.
config
.
d_ff
)
**
-
0.5
))
if
hasattr
(
module
.
wo
,
"bias"
)
and
module
.
wo
.
bias
is
not
None
:
module
.
wo
.
bias
.
data
.
zero_
()
elif
isinstance
(
module
,
SwitchTransformersDenseGatedActDense
):
module
.
wi_0
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
((
self
.
config
.
d_model
)
**
-
0.5
))
if
hasattr
(
module
.
wi_0
,
"bias"
)
and
module
.
wi_0
.
bias
is
not
None
:
module
.
wi_0
.
bias
.
data
.
zero_
()
module
.
wi_1
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
((
self
.
config
.
d_model
)
**
-
0.5
))
if
hasattr
(
module
.
wi_1
,
"bias"
)
and
module
.
wi_1
.
bias
is
not
None
:
module
.
wi_1
.
bias
.
data
.
zero_
()
module
.
wo
.
weight
.
data
.
normal_
(
mean
=
0.0
,
std
=
factor
*
((
self
.
config
.
d_ff
)
**
-
0.5
))
if
hasattr
(
module
.
wo
,
"bias"
)
and
module
.
wo
.
bias
is
not
None
:
module
.
wo
.
bias
.
data
.
zero_
()
elif
isinstance
(
module
,
SwitchTransformersAttention
):
# Mesh TensorFlow attention initialization to avoid scaling before softmax
# See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment