"...git@developer.sourcefind.cn:chenpangpang/transformers.git" did not exist on "f49a3453caa6fe606bb31c571423f72264152fce"
Unverified commit 5347d000, authored by Arthur, committed by GitHub

[`SwitchTransformers`] Remove unused module (#25427)

* remove unused module

* remove old feed_forward_proj

* fixup
parent d6bf08f7
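For context, after this change the activation of the dense layers is passed to the config directly as `dense_act_fn` instead of being derived from `feed_forward_proj`. A minimal usage sketch (illustrative only, not part of the commit):

```python
from transformers import SwitchTransformersConfig

# The experts' dense-layer activation is now configured directly.
config = SwitchTransformersConfig(dense_act_fn="relu")
print(config.dense_act_fn)  # "relu"
```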
@@ -122,7 +122,7 @@ class SwitchTransformersConfig(PretrainedConfig):
         router_z_loss_coef=0.001,
         router_aux_loss_coef=0.001,
         initializer_factor=1.0,
-        feed_forward_proj="relu",
+        dense_act_fn="relu",
         is_encoder_decoder=True,
         add_router_probs=False,
         use_cache=True,
@@ -171,27 +171,12 @@ class SwitchTransformersConfig(PretrainedConfig):
         self.dropout_rate = dropout_rate
         self.layer_norm_epsilon = layer_norm_epsilon
         self.initializer_factor = initializer_factor
-        self.feed_forward_proj = feed_forward_proj
         self.use_cache = use_cache
         self.add_router_probs = add_router_probs
         self.router_z_loss_coef = router_z_loss_coef
         self.router_aux_loss_coef = router_aux_loss_coef
+        self.dense_act_fn = dense_act_fn
-
-        act_info = self.feed_forward_proj.split("-")
-        self.dense_act_fn = act_info[-1]
-        self.is_gated_act = act_info[0] == "gated"
-
-        if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
-            raise ValueError(
-                f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer."
-                "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
-                "'gated-gelu' or 'relu'"
-            )
-
-        # for backwards compatibility
-        if feed_forward_proj == "gated-gelu":
-            self.dense_act_fn = "gelu_new"

         super().__init__(
             pad_token_id=pad_token_id,
......
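The behavior dropped here is the parsing of `feed_forward_proj` into an activation name and a gating flag. A self-contained sketch of that now-deleted logic, mirroring the removed lines above (`resolve_act` is a hypothetical helper name used only for this illustration):

```python
def resolve_act(feed_forward_proj: str):
    # Split e.g. "gated-gelu" into ("gated", "gelu"); plain names pass through.
    act_info = feed_forward_proj.split("-")
    dense_act_fn = act_info[-1]
    is_gated_act = act_info[0] == "gated"
    if len(act_info) > 1 and act_info[0] != "gated" or len(act_info) > 2:
        raise ValueError(
            f"`feed_forward_proj`: {feed_forward_proj} is not a valid activation function of the dense layer. "
            "Please make sure `feed_forward_proj` is of the format `gated-{ACT_FN}` or `{ACT_FN}`, e.g. "
            "'gated-gelu' or 'relu'"
        )
    # Old backwards-compatibility alias that the commit removes.
    if feed_forward_proj == "gated-gelu":
        dense_act_fn = "gelu_new"
    return dense_act_fn, is_gated_act


print(resolve_act("relu"))        # ('relu', False)
print(resolve_act("gated-gelu"))  # ('gelu_new', True)
```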
@@ -282,25 +282,6 @@ class SwitchTransformersDenseActDense(nn.Module):
         return hidden_states
 
 
-# Copied from transformers.models.longt5.modeling_longt5.LongT5DenseGatedActDense with LongT5->SwitchTransformers
-class SwitchTransformersDenseGatedActDense(nn.Module):
-    def __init__(self, config: SwitchTransformersConfig):
-        super().__init__()
-        self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
-        self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
-        self.dropout = nn.Dropout(config.dropout_rate)
-        self.act = ACT2FN[config.dense_act_fn]
-
-    def forward(self, hidden_states):
-        hidden_gelu = self.act(self.wi_0(hidden_states))
-        hidden_linear = self.wi_1(hidden_states)
-        hidden_states = hidden_gelu * hidden_linear
-        hidden_states = self.dropout(hidden_states)
-        hidden_states = self.wo(hidden_states)
-        return hidden_states
-
-
 class SwitchTransformersSparseMLP(nn.Module):
     r"""
     Implementation of the Switch Transformers Sparse MLP module.
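With `SwitchTransformersDenseGatedActDense` deleted, the experts rely only on the non-gated block, `SwitchTransformersDenseActDense` (the class named in the hunk header above). A minimal, self-contained sketch of that pattern, modeled on the T5-style dense layer; the class name and dimensions below are local to this snippet, not the library's exact code:

```python
import torch
from torch import nn
from transformers.activations import ACT2FN  # mapping from activation names to callables


class DenseActDenseSketch(nn.Module):
    """Illustrative non-gated feed-forward block: up-project, activate, dropout, down-project."""

    def __init__(self, d_model=512, d_ff=2048, dropout_rate=0.1, dense_act_fn="relu"):
        super().__init__()
        self.wi = nn.Linear(d_model, d_ff, bias=False)   # up-projection
        self.wo = nn.Linear(d_ff, d_model, bias=False)   # down-projection
        self.dropout = nn.Dropout(dropout_rate)
        self.act = ACT2FN[dense_act_fn]

    def forward(self, hidden_states):
        hidden_states = self.act(self.wi(hidden_states))
        hidden_states = self.dropout(hidden_states)
        return self.wo(hidden_states)


x = torch.randn(2, 4, 512)
print(DenseActDenseSketch()(x).shape)  # torch.Size([2, 4, 512])
```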
@@ -861,16 +842,6 @@ class SwitchTransformersPreTrainedModel(PreTrainedModel):
             module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
             if hasattr(module.wo, "bias") and module.wo.bias is not None:
                 module.wo.bias.data.zero_()
-        elif isinstance(module, SwitchTransformersDenseGatedActDense):
-            module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
-                module.wi_0.bias.data.zero_()
-            module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
-            if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
-                module.wi_1.bias.data.zero_()
-            module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
-            if hasattr(module.wo, "bias") and module.wo.bias is not None:
-                module.wo.bias.data.zero_()
         elif isinstance(module, SwitchTransformersAttention):
            # Mesh TensorFlow attention initialization to avoid scaling before softmax
            # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
......
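The deleted `_init_weights` branch used the same Mesh-TensorFlow-style scaling that remains in place for the kept dense module: input projections are initialized with std `factor * d_model ** -0.5` and the output projection with `factor * d_ff ** -0.5`. A standalone illustration of that scaling (the dimensions are made-up example values, not taken from any checkpoint):

```python
from torch import nn

# Example dimensions chosen only for this snippet.
d_model, d_ff, factor = 512, 2048, 1.0

wi = nn.Linear(d_model, d_ff, bias=False)   # up-projection
wo = nn.Linear(d_ff, d_model, bias=False)   # down-projection

# Mesh-TensorFlow-style scaling, matching the initialization shown in the hunk above.
wi.weight.data.normal_(mean=0.0, std=factor * d_model ** -0.5)
wo.weight.data.normal_(mean=0.0, std=factor * d_ff ** -0.5)

print(wi.weight.std().item(), wo.weight.std().item())  # roughly 0.044 and 0.022
```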