"git@developer.sourcefind.cn:hehl2/torchaudio.git" did not exist on "575d221ec4b394fcc8e5bb8147d0e7ba012afe56"
Commit d64db6df authored by lukovnikov

clean up pr

parent 7ba83730
@@ -25,10 +25,7 @@ import six
 import torch
 import torch.nn as nn
 from torch.nn import CrossEntropyLoss
-
-
-ACT2FN = {"gelu": gelu, "relu": torch.nn.ReLU, "swish": swish}
-
-
+from six import string_types
+
 def gelu(x):
     """Implementation of the gelu activation function.
@@ -42,6 +39,9 @@ def swish(x):
     return x * torch.sigmoid(x)
 
 
+ACT2FN = {"gelu": gelu, "relu": torch.nn.functional.relu, "swish": swish}
+
+
 class BertConfig(object):
     """Configuration class to store the configuration of a `BertModel`.
     """
@@ -68,7 +68,7 @@ class BertConfig(object):
         intermediate_size: The size of the "intermediate" (i.e., feed-forward)
             layer in the Transformer encoder.
         hidden_act: The non-linear activation function (function or string) in the
-            encoder and pooler. If string, "gelu", "relu" and "swish" supported.
+            encoder and pooler. If string, "gelu", "relu" and "swish" are supported.
         hidden_dropout_prob: The dropout probability for all fully connected
             layers in the embeddings, encoder, and pooler.
         attention_probs_dropout_prob: The dropout ratio for the attention
@@ -246,7 +246,7 @@ class BERTIntermediate(nn.Module):
         super(BERTIntermediate, self).__init__()
         self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
         self.intermediate_act_fn = ACT2FN[config.hidden_act] \
-            if isinstance(config.hidden_act, str) else config.hidden_act
+            if isinstance(config.hidden_act, string_types) else config.hidden_act
 
     def forward(self, hidden_states):
         hidden_states = self.dense(hidden_states)
...
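The net effect of the hunks is easiest to see in one place. Below is a minimal, self-contained sketch of the resulting pattern, not a copy of the repository file: gelu and swish are defined before the ACT2FN registry that references them, the "relu" entry points at torch.nn.functional.relu rather than the torch.nn.ReLU module class, and string-or-callable dispatch uses six.string_types for Python 2 compatibility. The helper name resolve_act is hypothetical; in the repository the same one-liner lives inline in BERTIntermediate.__init__.

import math

import torch
import torch.nn.functional as F
from six import string_types


def gelu(x):
    """erf-based GELU, as in BERT-era PyTorch implementations."""
    return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def swish(x):
    return x * torch.sigmoid(x)


# Defined *after* gelu and swish so those names exist, and using
# torch.nn.functional.relu (a function) instead of torch.nn.ReLU (a Module
# class): ACT2FN["relu"](x) must apply relu, not construct a module.
ACT2FN = {"gelu": gelu, "relu": F.relu, "swish": swish}


def resolve_act(hidden_act):
    # Hypothetical helper mirroring the BERTIntermediate dispatch line.
    # six.string_types covers both str and unicode under Python 2, which a
    # plain isinstance(hidden_act, str) check would miss.
    if isinstance(hidden_act, string_types):
        return ACT2FN[hidden_act]
    return hidden_act


x = torch.randn(3)
print(resolve_act("gelu")(x))   # looked up by config string
print(resolve_act(swish)(x))    # a callable passes through unchanged

Relocating ACT2FN below swish also removes the ordering problem of the original top-of-file placement, where the dict referenced gelu and swish before either function was defined.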