# Copyright 2020 Huy Le Nguyen (@usimarit)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import tensorflow as tf


class GLU(tf.keras.layers.Layer):
    """Gated linear unit activation layer.

    The input is split into two equal parts along ``axis``; the first part
    is multiplied element-wise by the sigmoid of the second part. The output
    therefore has half the input's size along ``axis``.
    """

    def __init__(self, axis=-1, name="glu_activation", **kwargs):
        """Create the layer.

        Args:
            axis: Axis along which the input is split into two halves.
            name: Layer name forwarded to the Keras base class.
            **kwargs: Extra keyword arguments forwarded to the base class.
        """
        super().__init__(name=name, **kwargs)
        self.axis = axis

    def call(self, inputs, **kwargs):
        """Apply the gate to ``inputs``.

        Args:
            inputs: Tensor to split in two along ``self.axis``.
            **kwargs: Accepted for call-signature compatibility; unused.

        Returns:
            The first half multiplied by the sigmoid of the second half.
        """
        # tf.split with an integer count yields equally sized pieces.
        linear_part, gate_part = tf.split(inputs, 2, axis=self.axis)
        return tf.multiply(linear_part, tf.nn.sigmoid(gate_part))

    def get_config(self):
        """Return the base layer config extended with the ``axis`` setting."""
        config = super().get_config()
        config["axis"] = self.axis
        return config