"git@developer.sourcefind.cn:zhaoyu6/sglang.git" did not exist on "99c92ff24bc82bf54d4cdb4106cb3966f3fa31de"
Commit 09c5ae2f authored by Hongkun Yu, committed by A. Unique TensorFlower

Internal change

PiperOrigin-RevId: 310767440
parent 52e4ded8
# Lint as: python3
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,12 +20,98 @@ from __future__ import division
# from __future__ import google_type_annotations
from __future__ import print_function
import collections
import math
import string
import numpy as np
import tensorflow as tf
from official.nlp.modeling.layers import dense_einsum
from official.nlp.modeling.layers import masked_softmax
EinsumDense = tf.keras.layers.experimental.EinsumDense
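# Note (illustrative, not part of this change): EinsumDense applies a dense
# projection expressed as an einsum equation. For example, "abc,cde->abde"
# maps a (batch, seq, dim) input through a (dim, num_heads, head_dim) kernel
# to a (batch, seq, num_heads, head_dim) output; the per-head query/key/value
# projections below are built from equations of this form.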
_CHR_IDX = string.ascii_lowercase
def _build_attention_equation(qkv_rank, attn_axes):
"""Builds einsum equations for the attention computation.
Query, key, and value inputs after projection are expected to have the shape:
(bs, <non-attention dims>, <attention dims>, num_heads, channels).
bs and <non-attention dims> are treated as <batch dims>.
The attention operations can be generalized:
(1) Query-key dot product:
(<batch dims>, <query attention dims>, num_heads, channels), (<batch dims>,
<key attention dims>, num_heads, channels) -> (<batch dims>,
num_heads, <query attention dims>, <key attention dims>)
(2) Combination:
(<batch dims>, num_heads, <query attention dims>, <key attention dims>),
(<batch dims>, <value attention dims>, num_heads, channels) -> (<batch dims>,
<query attention dims>, num_heads, channels)
Args:
qkv_rank: Rank of the query, key, and value tensors.
attn_axes: List/tuple of axes, in [1, rank), over which attention is applied.
Returns:
A tuple (dot_product_equation, combine_equation) of einsum equation strings.
"""
target_notation = _CHR_IDX[:qkv_rank]
# `batch_dims` includes the head dim.
batch_dims = tuple(np.delete(range(qkv_rank), attn_axes + (qkv_rank - 1,)))
letter_offset = qkv_rank
source_notation = ""
for i in range(qkv_rank):
if i in batch_dims or i == qkv_rank - 1:
source_notation += target_notation[i]
else:
source_notation += _CHR_IDX[letter_offset]
letter_offset += 1
product_notation = "".join([target_notation[i] for i in batch_dims] +
[target_notation[i] for i in attn_axes] +
[source_notation[i] for i in attn_axes])
dot_product_equation = "%s,%s->%s" % (source_notation, target_notation,
product_notation)
combine_equation = "%s,%s->%s" % (product_notation, source_notation,
target_notation)
return dot_product_equation, combine_equation
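# Illustrative check (not part of this change): for rank-4 projected tensors
# of shape (B, T, N, H) with attention over axis 1, this helper produces
# equations equivalent to the previously hard-coded
# "BSNH,BTNH->BNTS" / "BNTS,BSNH->BTNH" pair, up to letter renaming:
#   dot_eq, combine_eq = _build_attention_equation(4, attn_axes=(1,))
#   dot_eq     == "aecd,abcd->acbe"   # (B,S,N,H),(B,T,N,H) -> (B,N,T,S)
#   combine_eq == "acbe,aecd->abcd"   # (B,N,T,S),(B,S,N,H) -> (B,T,N,H)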
def _build_proj_equation(free_dims, bound_dims, output_dims):
"""Builds an einsum equation for projections inside multi-head attention."""
input_str = ""
kernel_str = ""
output_str = ""
bias_axes = ""
letter_offset = 0
for i in range(free_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
output_str += char
letter_offset += free_dims
for i in range(bound_dims):
char = _CHR_IDX[i + letter_offset]
input_str += char
kernel_str += char
letter_offset += bound_dims
for i in range(output_dims):
char = _CHR_IDX[i + letter_offset]
kernel_str += char
output_str += char
bias_axes += char
equation = "%s,%s->%s" % (input_str, kernel_str, output_str)
# The output rank does not consider the batch dimension.
output_rank = len(output_str) - 1
return equation, bias_axes, output_rank
def _get_output_shape(output_rank, known_last_dims):
return [None] * (output_rank - len(known_last_dims)) + list(known_last_dims)
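# Illustrative check (not part of this change): the query/key/value projection
# of a rank-3 (batch, seq, dim) input uses
#   eq, bias_axes, rank = _build_proj_equation(free_dims=2, bound_dims=1,
#                                              output_dims=2)
#   eq == "abc,cde->abde", bias_axes == "de", rank == 3
#   _get_output_shape(rank, [num_heads, key_size]) == [None, num_heads, key_size]
# while the output projection uses bound_dims=2 to sum over the
# (num_heads, value_size) axes, giving "abcd,cde->abe".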
@tf.keras.utils.register_keras_serializable(package="Text")
class MultiHeadAttention(tf.keras.layers.Layer):
@@ -53,7 +140,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
key_size: Size of each attention head for query and key.
value_size: Size of each attention head for value.
dropout: Dropout probability.
use_bias: Boolean, whether the dense layers use bias vectors.
use_bias: Boolean, whether the dense layers use bias vectors/matrices.
output_shape: The expected shape of an output tensor, besides the batch and
sequence dims. If not specified, projects back to the key feature dim.
kernel_initializer: Initializer for dense layer kernels.
@@ -94,44 +181,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
self._kernel_constraint = tf.keras.constraints.get(kernel_constraint)
self._bias_constraint = tf.keras.constraints.get(bias_constraint)
self._query_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._key_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="query")
self._key_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._key_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="key")
self._value_dense = dense_einsum.DenseEinsum(
output_shape=(self._num_heads, self._value_size),
use_bias=self._use_bias,
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="value")
self._masked_softmax = masked_softmax.MaskedSoftmax(mask_expansion_axes=[1])
self._dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
def get_config(self):
@@ -167,22 +217,72 @@ class MultiHeadAttention(tf.keras.layers.Layer):
return dict(list(base_config.items()) + list(config.items()))
def build(self, input_shape):
if self._output_shape:
output_shape = self._output_shape
else:
input_shape = tf.TensorShape(input_shape[0])
output_shape = input_shape[-1]
self._output_dense = dense_einsum.DenseEinsum(
output_shape=output_shape,
num_summed_dimensions=2,
inputs_len = len(input_shape)
if inputs_len > 3 or inputs_len < 2:
raise ValueError(
"Expects inputs list of length 2 or 3, namely [query, value] or "
"[query, value, key]. "
"Given length: %d" % inputs_len)
tensor_shapes = tf.nest.map_structure(tf.TensorShape, input_shape)
query_shape = tensor_shapes[0]
value_shape = tensor_shapes[1]
key_shape = tensor_shapes[2] if inputs_len == 3 else value_shape
common_kwargs = dict(
kernel_initializer=self._kernel_initializer,
bias_initializer=self._bias_initializer,
kernel_regularizer=self._kernel_regularizer,
bias_regularizer=self._bias_regularizer,
activity_regularizer=self._activity_regularizer,
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="attention_output")
bias_constraint=self._bias_constraint)
free_dims = query_shape.rank - 1
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=1, output_dims=2)
self._query_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="query",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
key_shape.rank - 1, bound_dims=1, output_dims=2)
self._key_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank,
[self._num_heads, self._key_size]),
bias_axes=bias_axes if self._use_bias else None,
name="key",
**common_kwargs)
einsum_equation, bias_axes, output_rank = _build_proj_equation(
value_shape.rank - 1, bound_dims=1, output_dims=2)
self._value_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank,
[self._num_heads, self._value_size]),
bias_axes=bias_axes if self._use_bias else None,
name="value",
**common_kwargs)
self._dot_product_equation, self._combine_equation = (
_build_attention_equation(output_rank + 1, attn_axes=(1,)))
if self._output_shape:
if not isinstance(self._output_shape, collections.abc.Sized):
output_shape = [self._output_shape]
else:
output_shape = self._output_shape
else:
output_shape = [query_shape[-1]]
einsum_equation, bias_axes, output_rank = _build_proj_equation(
free_dims, bound_dims=2, output_dims=len(output_shape))
self._output_dense = EinsumDense(
einsum_equation,
output_shape=_get_output_shape(output_rank, output_shape),
bias_axes=bias_axes if self._use_bias else None,
name="attention_output",
**common_kwargs)
super(MultiHeadAttention, self).build(input_shape)
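# Illustrative usage sketch (not part of this change; constructor argument
# names are assumed from the docstring above):
#   layer = MultiHeadAttention(num_heads=8, key_size=64)
#   query = tf.keras.Input(shape=(20, 512))   # (batch, target_len, dim)
#   value = tf.keras.Input(shape=(10, 512))   # (batch, source_len, dim)
#   output = layer([query, value])            # -> (batch, 20, 512)
#   # an optional mask is passed as layer([query, value], attention_mask=mask)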
def call(self, inputs, attention_mask=None):
@@ -234,7 +334,8 @@ class MultiHeadAttention(tf.keras.layers.Layer):
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BSNH,BTNH->BNTS", key_tensor, query_tensor)
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
@@ -247,7 +348,7 @@ class MultiHeadAttention(tf.keras.layers.Layer):
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, T, N, H]
attention_output = tf.einsum("BNTS,BSNH->BTNH", attention_probs,
attention_output = tf.einsum(self._combine_equation, attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
@@ -288,11 +389,14 @@ class CachedAttention(MultiHeadAttention):
return key_tensor, value_tensor
def call(self, inputs, decode_loop_step=None):
def call(self,
inputs,
attention_mask=None,
cache=None,
decode_loop_step=None):
from_tensor = inputs[0]
to_tensor = inputs[1]
attention_mask = inputs[2] if len(inputs) >= 3 else None
cache = inputs[3] if len(inputs) >= 4 else None
# Scalar dimensions referenced here:
# B = batch size (number of sequences)
# F = `from_tensor` sequence length
@@ -314,7 +418,8 @@ class CachedAttention(MultiHeadAttention):
# Take the dot product between "query" and "key" to get the raw
# attention scores.
attention_scores = tf.einsum("BTNH,BFNH->BNFT", key_tensor, query_tensor)
attention_scores = tf.einsum(self._dot_product_equation, key_tensor,
query_tensor)
attention_scores = tf.multiply(attention_scores,
1.0 / math.sqrt(float(self._key_size)))
@@ -326,7 +431,7 @@ class CachedAttention(MultiHeadAttention):
# seem a bit unusual, but is taken from the original Transformer paper.
attention_probs = self._dropout(attention_probs)
# `context_layer` = [B, F, N, H]
attention_output = tf.einsum("BNFT,BTNH->BFNH", attention_probs,
attention_output = tf.einsum(self._combine_equation, attention_probs,
value_tensor)
attention_output = self._output_dense(attention_output)
return attention_output, cache
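# Illustrative sketch (not part of this change): with this signature the mask
# and cache move out of the `inputs` list and into keyword arguments, e.g.
#   output, cache = cached_layer([from_data, from_data],
#                                attention_mask=mask_data,
#                                cache=cache,
#                                decode_loop_step=decode_loop_step)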
@@ -99,6 +99,13 @@ class MultiHeadAttentionTest(keras_parameterized.TestCase):
# same.
self.assertNotAllClose(masked_output_data, unmasked_output_data)
if use_bias:
self.assertLen(test_layer._query_dense.trainable_variables, 2)
self.assertLen(test_layer._output_dense.trainable_variables, 2)
else:
self.assertLen(test_layer._query_dense.trainable_variables, 1)
self.assertLen(test_layer._output_dense.trainable_variables, 1)
def test_initializer(self):
"""Test with a specified initializer."""
test_layer = attention.MultiHeadAttention(
@@ -143,7 +150,7 @@ class CachedAttentionTest(keras_parameterized.TestCase):
# one element.
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length))
masked_output_data, cache = layer([from_data, from_data, mask_data, cache])
masked_output_data, cache = layer([from_data, from_data], mask_data, cache)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
@@ -170,7 +177,9 @@ class CachedAttentionTest(keras_parameterized.TestCase):
mask_data = np.random.randint(
2, size=(batch_size, from_seq_length, from_seq_length), dtype=np.int32)
# Test the invocation directly, as Keras cannot consume these inputs correctly.
masked_output_data, cache = layer([from_data, from_data, mask_data, cache],
masked_output_data, cache = layer([from_data, from_data],
mask_data,
cache,
decode_loop_step=decode_loop_step)
self.assertEqual(masked_output_data.shape, (3, 4, 8))
self.assertEqual(cache["value"].shape, (3, 4, 2, 2))
@@ -116,10 +116,6 @@ class Transformer(tf.keras.layers.Layer):
kernel_constraint=self._kernel_constraint,
bias_constraint=self._bias_constraint,
name="self_attention")
# TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
# pylint: disable=protected-access
self._attention_layer.build([input_tensor_shape])
self._attention_output_dense = self._attention_layer._output_dense
self._attention_dropout = tf.keras.layers.Dropout(rate=self._dropout_rate)
# Use float32 in layernorm for numeric stability.
@@ -95,12 +95,6 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
output_shape=self.hidden_size,
kernel_initializer=self._kernel_initializer,
name="attention/encdec")
# TODO(hongkuny): Remove when checkpoint backward compatibility is resolved.
# pylint: disable=protected-access
self.self_attention.build(input_shape)
self.self_attention_output_dense = self.self_attention._output_dense
self.encdec_attention.build(input_shape)
self.encdec_attention_output_dense = self.encdec_attention._output_dense
self.encdec_attention_dropout = tf.keras.layers.Dropout(
rate=self.hidden_dropout_prob)
@@ -145,14 +139,12 @@ class TransformerDecoderBlock(tf.keras.layers.Layer):
"TransformerDecoderBlock must have 4 inputs, but it got: %d" %
len(inputs))
input_tensor, memory, attention_mask, self_attention_mask = inputs[:4]
if cache is None:
self_attention_inputs = [input_tensor, input_tensor, self_attention_mask]
else:
self_attention_inputs = [
input_tensor, input_tensor, self_attention_mask, cache
]
self_attention_inputs = [input_tensor, input_tensor]
self_attention_output, cache = self.self_attention(
self_attention_inputs, decode_loop_step=decode_loop_step)
self_attention_inputs,
attention_mask=self_attention_mask,
cache=cache,
decode_loop_step=decode_loop_step)
self_attention_output = self.self_attention_dropout(self_attention_output)
self_attention_output = self.self_attention_layer_norm(
input_tensor + self_attention_output)