Unverified commit aad41340, authored by Hongkun Yu and committed by GitHub

Merged commit includes the following changes: (#7324)

260601376 by hongkuny <hongkuny@google.com>:

    Reorder the Q, K einsum operands so attention runs faster on TPU.

--

PiperOrigin-RevId: 260601376
parent d65af7d8
@@ -365,7 +365,7 @@ class Attention(tf.keras.layers.Layer):
     Q:[BFNH] = einsum('BFD,DNH->BFNH', Input_tensor, Wq)
     K:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wk)
     V:[BTNH] = einsum('BTD,DNH->BTNH', Input_tensor, Wv)
-    attention_scores:[BNFT] = einsum('BFNH,BTNH->BNFT', Q, K) / sqrt(H)
+    attention_scores:[BNFT] = einsum('BTNH,BFNH->BNFT', K, Q) / sqrt(H)
     attention_probs:[BNFT] = softmax(attention_scores)
     context_layer:[BFNH] = einsum('BNFT,BTNH->BFNH', attention_probs, V)
     Wout:[DNH]
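The docstring above maps directly onto tf.einsum calls. Below is a minimal sketch of the projection/attention flow it describes, with the reordered (K, Q) operands from this change; the dimension values B, F, T, N, H, D are hypothetical placeholders, not values taken from the model.

import math
import tensorflow as tf

# Hypothetical sizes: B=batch, F=query length, T=key length
# (self-attention, so F == T), D=hidden width, N=heads, H=size per head.
B, F, T, N, H = 2, 4, 4, 3, 8
D = N * H

input_tensor = tf.random.normal([B, F, D])
Wq = tf.random.normal([D, N, H])
Wk = tf.random.normal([D, N, H])
Wv = tf.random.normal([D, N, H])

q = tf.einsum("BFD,DNH->BFNH", input_tensor, Wq)  # Q:[BFNH]
k = tf.einsum("BTD,DNH->BTNH", input_tensor, Wk)  # K:[BTNH]
v = tf.einsum("BTD,DNH->BTNH", input_tensor, Wv)  # V:[BTNH]

# Attention scores with the reordered (K, Q) operands from this commit.
scores = tf.einsum("BTNH,BFNH->BNFT", k, q) / math.sqrt(float(H))
probs = tf.nn.softmax(scores, axis=-1)            # attention_probs:[BNFT]
context = tf.einsum("BNFT,BTNH->BFNH", probs, v)  # context_layer:[BFNH]
print(context.shape)  # (2, 4, 3, 8)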
@@ -433,7 +433,7 @@ class Attention(tf.keras.layers.Layer):
     # Take the dot product between "query" and "key" to get the raw
     # attention scores.
-    attention_scores = tf.einsum("BFNH,BTNH->BNFT", query_tensor, key_tensor)
+    attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
     attention_scores = tf.multiply(attention_scores,
                                    1.0 / math.sqrt(float(self.size_per_head)))
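Swapping the einsum operands changes only how the contraction is scheduled (the commit message reports this is faster on TPU), not the result. A quick sketch, using hypothetical shapes, that checks numerically that both orderings produce identical [B, N, F, T] scores:

import numpy as np
import tensorflow as tf

B, F, T, N, H = 2, 3, 5, 4, 8  # hypothetical dimensions
query_tensor = tf.random.normal([B, F, N, H])  # [BFNH]
key_tensor = tf.random.normal([B, T, N, H])    # [BTNH]

old_scores = tf.einsum("BFNH,BTNH->BNFT", query_tensor, key_tensor)
new_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
np.testing.assert_allclose(
    old_scores.numpy(), new_scores.numpy(), rtol=1e-5, atol=1e-5)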