Unverified commit f3558bbc, authored by Xa9aX ツ, committed by GitHub
Browse files

Deprecate pythonic Mish and support PyTorch 1.9 version of Mish (#12240)

* Moved Mish to Torch 1.9 version

* Run black formatting
parent 47a97683
...@@ -73,10 +73,20 @@ else: ...@@ -73,10 +73,20 @@ else:
silu = nn.functional.silu silu = nn.functional.silu
def mish(x): def _mish_python(x):
"""
See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also
visit the official repository for the paper: https://github.com/digantamisra98/Mish
"""
return x * torch.tanh(nn.functional.softplus(x)) return x * torch.tanh(nn.functional.softplus(x))
if version.parse(torch.__version__) < version.parse("1.9"):
mish = _mish_python
else:
mish = nn.functional.mish
def linear_act(x): def linear_act(x):
return x return x
......
...@@ -140,10 +140,6 @@ def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path): ...@@ -140,10 +140,6 @@ def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
return model return model
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
class NoNorm(nn.Module): class NoNorm(nn.Module):
def __init__(self, feat_size, eps=None): def __init__(self, feat_size, eps=None):
super().__init__() super().__init__()
......
...@@ -138,10 +138,6 @@ def load_tf_weights_in_{{cookiecutter.lowercase_modelname}}(model, config, tf_ch ...@@ -138,10 +138,6 @@ def load_tf_weights_in_{{cookiecutter.lowercase_modelname}}(model, config, tf_ch
return model return model
def mish(x):
return x * torch.tanh(nn.functional.softplus(x))
# Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}} # Copied from transformers.models.bert.modeling_bert.BertEmbeddings with Bert->{{cookiecutter.camelcase_modelname}}
class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module): class {{cookiecutter.camelcase_modelname}}Embeddings(nn.Module):
"""Construct the embeddings from word, position and token_type embeddings.""" """Construct the embeddings from word, position and token_type embeddings."""
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment