Commit 3c227a73 authored by Hongkun Yu's avatar Hongkun Yu Committed by A. Unique TensorFlower
Browse files

Remove __init__ in CachedAttention, which is the same as its parent class.

PiperOrigin-RevId: 306748161
parent 26ea4d1a
......@@ -186,15 +186,9 @@ class MultiHeadAttention(tf.keras.layers.Layer):
class CachedAttention(MultiHeadAttention):
"""Attention layer with cache used for auto-agressive decoding.
Arguments:
num_heads: Number of attention heads.
head_size: Size of each attention head.
**kwargs: Other keyword arguments inherit from `Attention` class.
Arguments are the same as `MultiHeadAttention` layer.
"""
def __init__(self, num_heads, head_size, **kwargs):
super(CachedAttention, self).__init__(num_heads, head_size, **kwargs)
def _update_cache(self, key_tensor, value_tensor, cache, decode_loop_step):
"""Updates cache states and gets full-length key/value tensors."""
# Combines cached keys and values with new keys and values.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment