Commit d39321b1 authored by xinliupitt

docstrings

parent 23804bc5
@@ -521,9 +521,10 @@ class CachedAttention(MultiHeadAttention):
if cache:
key, value = self._update_cache(key, value, cache, decode_loop_step)
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
# Take the dot product between "query" and "key" to get the raw
# attention scores.
query = tf.multiply(query, 1.0 / math.sqrt(float(self._key_size)))
attention_scores = tf.einsum(self._dot_product_equation, key, query)
# Normalize the attention scores to probabilities.
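The hunk above updates the decode-time key/value cache, scales the query, takes the query-key dot product, and then normalizes the scores. Below is a minimal, self-contained sketch of those steps; the tensor shapes, the explicit einsum equation, and the simple `tf.concat`-based cache update are illustrative assumptions standing in for `self._dot_product_equation` and `_update_cache`, not the library code.

```python
import math

import tensorflow as tf

batch, heads, key_size = 2, 4, 16

# Cached keys/values from earlier decode steps, plus this step's projections.
# A plain concat stands in for _update_cache (assumption, not the real helper).
cache = {"key": tf.zeros([batch, 0, heads, key_size]),
         "value": tf.zeros([batch, 0, heads, key_size])}
new_key = tf.random.normal([batch, 1, heads, key_size])
new_value = tf.random.normal([batch, 1, heads, key_size])
key = tf.concat([cache["key"], new_key], axis=1)
value = tf.concat([cache["value"], new_value], axis=1)
cache["key"], cache["value"] = key, value

# Scale the query by 1/sqrt(key_size) before the dot product.
query = tf.random.normal([batch, 1, heads, key_size])
query = tf.multiply(query, 1.0 / math.sqrt(float(key_size)))

# Raw attention scores for a [batch, seq, heads, size] layout; this equation is
# an illustrative stand-in for self._dot_product_equation.
attention_scores = tf.einsum("aecd,abcd->acbe", key, query)

# Normalize the attention scores to probabilities over the key positions.
attention_probs = tf.nn.softmax(attention_scores, axis=-1)

# Weighted sum over the cached values: [batch, query_len, heads, key_size].
attention_output = tf.einsum("acbe,aecd->abcd", attention_probs, value)
print(attention_probs.shape, attention_output.shape)  # (2, 4, 1, 1) (2, 1, 4, 16)
```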
@@ -49,6 +49,10 @@ class Transformer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_bias: Whether to enable use_bias in attention layer.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
def __init__(self,
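The three arguments documented above control the bias terms, the placement of layer normalization, and its epsilon. The sketch below is generic Keras code, not the Transformer layer itself, showing the usual reading of `norm_first`: pre-layer-norm normalizes the sublayer input, post-layer-norm normalizes the residual sum. The `Dense` sublayer is a placeholder for the attention or intermediate dense layers.

```python
import tensorflow as tf

norm = tf.keras.layers.LayerNormalization(epsilon=1e-12)  # norm_epsilon
sublayer = tf.keras.layers.Dense(64, use_bias=True)       # stands in for attention / FFN

def block(x, norm_first):
  if norm_first:
    # Pre-norm: normalize the input before the sublayer, then add the residual.
    return x + sublayer(norm(x))
  # Post-norm: run the sublayer first, then normalize input + output.
  return norm(x + sublayer(x))

x = tf.random.normal([2, 10, 64])
print(block(x, norm_first=True).shape)   # (2, 10, 64)
print(block(x, norm_first=False).shape)  # (2, 10, 64)
```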
@@ -277,6 +281,10 @@ class TransformerDecoderLayer(tf.keras.layers.Layer):
activity_regularizer: Regularizer for dense layer activity.
kernel_constraint: Constraint for dense layer kernels.
bias_constraint: Constraint for dense layer biases.
use_bias: Whether to enable use_bias in attention layer.
norm_first: Whether to normalize inputs to attention and intermediate dense
layers.
norm_epsilon: Epsilon value to initialize normalization layers.
"""
def __init__(self,
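The same three arguments are documented for the decoder layer. As a rough illustration of where they enter a decoder block (self-attention, cross-attention, feed-forward), here is a toy block built from stock Keras layers; it is not the TransformerDecoderLayer API, and all layer names and sizes are made up.

```python
import tensorflow as tf

class TinyDecoderBlock(tf.keras.layers.Layer):
  """Toy decoder block; illustrative only, not the library class."""

  def __init__(self, hidden_size=64, num_heads=4,
               use_bias=True, norm_first=False, norm_epsilon=1e-12):
    super().__init__()
    self.norm_first = norm_first
    # use_bias toggles the bias terms on the attention and dense projections.
    self.self_attn = tf.keras.layers.MultiHeadAttention(
        num_heads, hidden_size // num_heads, use_bias=use_bias)
    self.cross_attn = tf.keras.layers.MultiHeadAttention(
        num_heads, hidden_size // num_heads, use_bias=use_bias)
    self.ffn = tf.keras.layers.Dense(hidden_size, use_bias=use_bias)
    # norm_epsilon initializes every LayerNormalization in the block.
    self.norms = [tf.keras.layers.LayerNormalization(epsilon=norm_epsilon)
                  for _ in range(3)]

  def _sublayer(self, x, fn, norm):
    if self.norm_first:
      # Pre-norm: normalize the sublayer input, then add the residual.
      return x + fn(norm(x))
    # Post-norm: normalize the residual sum.
    return norm(x + fn(x))

  def call(self, targets, memory):
    x = self._sublayer(targets, lambda t: self.self_attn(t, t), self.norms[0])
    x = self._sublayer(x, lambda t: self.cross_attn(t, memory), self.norms[1])
    return self._sublayer(x, self.ffn, self.norms[2])

block = TinyDecoderBlock(norm_first=True)
out = block(tf.random.normal([2, 5, 64]), tf.random.normal([2, 7, 64]))
print(out.shape)  # (2, 5, 64)
```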