Unverified Commit fc64559c authored by Julien Plu's avatar Julien Plu Committed by GitHub
Browse files

Fix TF CTRL model naming (#6134)

parent 641b873c
......@@ -70,14 +70,14 @@ extras["sklearn"] = ["scikit-learn"]
# keras2onnx and onnxconverter-common version is specific through a commit until 1.7.0 lands on pypi
extras["tf"] = [
"tensorflow<=2.2",
"tensorflow",
# "onnxconverter-common",
# "keras2onnx"
"onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
"keras2onnx @ git+git://github.com/onnx/keras-onnx.git@cbdc75cb950b16db7f0a67be96a278f8d2953b48#egg=keras2onnx"
]
extras["tf-cpu"] = [
"tensorflow-cpu<=2.2",
"tensorflow-cpu",
# "onnxconverter-common",
# "keras2onnx"
"onnxconverter-common @ git+git://github.com/microsoft/onnxconverter-common.git@f64ca15989b6dc95a1f3507ff6e4c395ba12dff5#egg=onnxconverter-common",
......@@ -86,7 +86,7 @@ extras["tf-cpu"] = [
extras["torch"] = ["torch"]
extras["serving"] = ["pydantic", "uvicorn", "fastapi", "starlette"]
extras["all"] = extras["serving"] + ["tensorflow<=2.2", "torch"]
extras["all"] = extras["serving"] + ["tensorflow", "torch"]
extras["testing"] = ["pytest", "pytest-xdist", "timeout-decorator", "psutil"]
# sphinx-rtd-theme==0.5.0 introduced big changes in the style.
......@@ -97,7 +97,7 @@ extras["quality"] = [
"isort @ git+git://github.com/timothycrosley/isort.git@e63ae06ec7d70b06df9e528357650281a3d3ec22#egg=isort",
"flake8",
]
extras["dev"] = extras["testing"] + extras["quality"] + ["mecab-python3<1", "scikit-learn", "tensorflow<=2.2", "torch"]
extras["dev"] = extras["testing"] + extras["quality"] + ["mecab-python3<1", "scikit-learn", "tensorflow", "torch"]
setup(
name="transformers",
......
......@@ -141,11 +141,18 @@ class TFMultiHeadAttention(tf.keras.layers.Layer):
return outputs
def point_wise_feed_forward_network(d_model_size, dff, name=""):
return tf.keras.Sequential(
[tf.keras.layers.Dense(dff, activation="relu", name="0"), tf.keras.layers.Dense(d_model_size, name="2")],
name="ffn",
)
class TFPointWiseFeedForwardLayer(tf.keras.layers.Layer):
def __init__(self, d_model_size, dff, **kwargs):
super().__init__(**kwargs)
self.dense_0 = tf.keras.layers.Dense(dff, activation="relu", name="0")
self.dense_2 = tf.keras.layers.Dense(d_model_size, name="2")
def call(self, inputs, trainable=False):
dense_0_output = self.dense_0(inputs)
dense_2_output = self.dense_2(dense_0_output)
return dense_2_output
class TFEncoderLayer(tf.keras.layers.Layer):
......@@ -153,7 +160,7 @@ class TFEncoderLayer(tf.keras.layers.Layer):
super().__init__(**kwargs)
self.multi_head_attention = TFMultiHeadAttention(d_model_size, num_heads, name="multi_head_attention")
self.ffn = point_wise_feed_forward_network(d_model_size, dff, name="ffn")
self.ffn = TFPointWiseFeedForwardLayer(d_model_size, dff, name="ffn")
self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm1")
self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=layer_norm_epsilon, name="layernorm2")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment