chenpangpang / transformers · Commits

Commit 5347d000 (unverified)
Authored Aug 17, 2023 by Arthur; committed by GitHub on Aug 17, 2023
[`SwitchTransformers`] Remove unused module (#25427)
* remove unused module
* remove old feed_forward_proj
* fixup
Parent: d6bf08f7
Showing 2 changed files, with 2 additions and 46 deletions:

* src/transformers/models/switch_transformers/configuration_switch_transformers.py (+2, -17)
* src/transformers/models/switch_transformers/modeling_switch_transformers.py (+0, -29)
src/transformers/models/switch_transformers/configuration_switch_transformers.py
```diff
@@ -122,7 +122,7 @@ class SwitchTransformersConfig(PretrainedConfig):
         router_z_loss_coef=0.001,
         router_aux_loss_coef=0.001,
         initializer_factor=1.0,
-        feed_forward_proj="relu",
+        dense_act_fn="relu",
         is_encoder_decoder=True,
         add_router_probs=False,
         use_cache=True,
```
```diff
@@ -171,27 +171,12 @@ class SwitchTransformersConfig(PretrainedConfig):
         self.dropout_rate = dropout_rate
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_factor = initializer_factor
-        self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
         self.add_router_probs = add_router_probs
 
         self.router_z_loss_coef = router_z_loss_coef
         self.router_aux_loss_coef = router_aux_loss_coef
-
-        act_info = self.feed_forward_proj.split("-")
-        self.dense_act_fn = act_info[-1]
-        self.is_gated_act = act_info[0] == "gated"
-
-        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
-            raise ValueError(
-                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
-                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
-                "'gated-gelu' or 'relu'"
-            )
-
-        # for backwards compatibility
-        if feed_forward_proj == "gated-gelu":
-            self.dense_act_fn = "gelu_new"
+        self.dense_act_fn = dense_act_fn
 
         super().__init__(
             pad_token_id=pad_token_id,
```
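With this commit, `SwitchTransformersConfig` takes the activation name directly through `dense_act_fn` rather than deriving it by parsing `feed_forward_proj`. A minimal usage sketch of the new surface (values match the defaults shown in the diff):

```python
from transformers import SwitchTransformersConfig

# New style: name the dense activation directly.
config = SwitchTransformersConfig(dense_act_fn="relu")
print(config.dense_act_fn)  # "relu"

# Old style removed by this commit: feed_forward_proj was split on "-",
# and "gated-gelu" was remapped to "gelu_new" for backwards compatibility.
# That parsing, and the is_gated_act flag it set, is gone.
```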
src/transformers/models/switch_transformers/modeling_switch_transformers.py
```diff
@@ -282,25 +282,6 @@ class SwitchTransformersDenseActDense(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.longt5.modeling_longt5.LongT5DenseGatedActDense with LongT5->SwitchTransformers
-class SwitchTransformersDenseGatedActDense(nn.Module):
-    def __init__(self, config: SwitchTransformersConfig):
-        super().__init__()
-        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
-        self.dropout = nn.Dropout(config.dropout_rate)
-        self.act = ACT2FN[config.dense_act_fn]
-
-    def forward(self, hidden_states):
-        hidden_gelu = self.act(self.wi_0(hidden_states))
-        hidden_linear = self.wi_1(hidden_states)
-        hidden_states = hidden_gelu * hidden_linear
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.wo(hidden_states)
-        return hidden_states
-
-
 class SwitchTransformersSparseMLP(nn.Module):
     r"""
     Implementation of the Switch Transformers Sparse MLP module.
```
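The deleted `SwitchTransformersDenseGatedActDense` was a gated feed-forward variant copied over from LongT5 that nothing in the file instantiated: the experts inside `SwitchTransformersSparseMLP` are built from the non-gated `SwitchTransformersDenseActDense` shown in the surrounding context. For contrast, a minimal standalone sketch of the surviving shape (hypothetical class name, not the library's):

```python
import torch
import torch.nn as nn

class DenseActDense(nn.Module):
    """Non-gated expert MLP, the shape Switch Transformers actually uses:
    one up-projection, activation, dropout, down-projection."""

    def __init__(self, d_model: int, d_ff: int, dropout_rate: float = 0.1):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)
        self.wo = nn.Linear(d_ff, d_model, bias=False)
        self.dropout = nn.Dropout(dropout_rate)
        self.act = nn.ReLU()  # stand-in for ACT2FN[config.dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.act(self.wi(hidden_states))
        return self.wo(self.dropout(hidden_states))

# The removed gated variant instead computed act(wi_0(x)) * wi_1(x) before
# dropout and the wo down-projection: two up-projections whose outputs are
# multiplied elementwise.
x = torch.randn(2, 8, 512)
print(DenseActDense(512, 2048)(x).shape)  # torch.Size([2, 8, 512])
```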
```diff
@@ -861,16 +842,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
             module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.wo, "bias") and module.wo.bias is not None:
                 module.wo.bias.data.zero_()
-        elif isinstance(module, SwitchTransformersDenseGatedActDense):
-            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
-                module.wi_0.bias.data.zero_()
-            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
-                module.wi_1.bias.data.zero_()
-            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
-            if hasattr(module.wo, "bias") and module.wo.bias is not None:
-                module.wo.bias.data.zero_()
         elif isinstance(module, SwitchTransformersAttention):
             # Mesh TensorFlow attention initialization to avoid scaling before softmax
             # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
```
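The `_init_weights` branch removed here applied the same Mesh-TensorFlow-style scheme the model uses for its other dense layers: each weight is drawn from a normal distribution with standard deviation `factor * fan_in ** -0.5` (fan-in is `d_model` for the up-projections `wi_0`/`wi_1` and `d_ff` for `wo`), with any bias zeroed. A minimal sketch of that rule on a bare `nn.Linear` (the helper name is ours, not the library's):

```python
import torch.nn as nn

def mesh_tf_normal_init_(linear: nn.Linear, fan_in: int, factor: float = 1.0) -> None:
    """Weight ~ N(0, (factor * fan_in ** -0.5) ** 2); zero the bias if one exists."""
    linear.weight.data.normal_(mean=0.0, std=factor * (fan_in**-0.5))
    if linear.bias is not None:
        linear.bias.data.zero_()

d_model, d_ff = 512, 2048
wi = nn.Linear(d_model, d_ff, bias=False)  # up-projection reads d_model inputs
wo = nn.Linear(d_ff, d_model, bias=False)  # down-projection reads d_ff inputs
mesh_tf_normal_init_(wi, fan_in=d_model)
mesh_tf_normal_init_(wo, fan_in=d_ff)
```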