Commit d493fa8d authored by Sinan Tan, committed by xuehui

Remove unused classes for SQuAD QA example.

parent dfef39d1
@@ -197,192 +197,4 @@ class DotAttention:
        '''
        buf = s * tf.expand_dims(prob, axis=-1)
        att = tf.reduce_sum(buf, axis=-3)
        return att
\ No newline at end of file
class MultiHeadAttention:
'''
MultiHeadAttention.
'''
def __init__(self, name, hidden_dim, head, add=True, dot=True, divide=True):
self._name = '/'.join([name, 'dot_att'])
self._head = head
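        # Each head attends over head_dim = hidden_dim // head units; the
        # effective hidden size is rounded down to a multiple of the head count.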
self._head_dim = hidden_dim // head
self._hidden_dim = self._head_dim * head
self._add = add
self._dot = dot
        assert add or dot, "you must choose at least one of add and dot"
self._div = 1.0
if divide:
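            # Scale dot-product scores by sqrt(head_dim), as in scaled
            # dot-product attention, to keep the softmax logits well-conditioned.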
self._div = math.sqrt(self._head_dim)
self._var = {}
@property
def hidden_dim(self):
return self._head_dim * self._head
@property
def name(self):
return self._name
@property
def var(self):
return self._var
def _get_var(self, name, shape, initializer=None):
with tf.variable_scope(self.name):
return _get_variable(self.var, name, shape, initializer)
def _define_params(self, tgt_dim):
self._get_var('tgt_project', [tgt_dim, self._hidden_dim])
self._get_var('tgt_bias', [1, self._hidden_dim])
self._get_var('v', [self._head, self._head_dim, 1])
def get_pre_compute(self, src):
s_shape = src.get_shape().as_list()
src_dim = s_shape[-1]
src_project = self._get_var('src_project', [src_dim, self._hidden_dim])
src_bias = self._get_var('src_bias', [1, self._hidden_dim])
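        # Project the source to hidden_dim, then split the last dimension
        # into [head, head_dim] so that each head attends independently.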
src = split_last_dim(tf.tensordot(src, src_project,
[[2], [0]]) + src_bias, self._head)
return src
def get_prob(self, src, tgt, mask, pre_compute):
        '''
        :param src: [src_sequence_length, batch_size, src_dim]
        :param tgt: [batch_size, tgt_dim] or [tgt_sequence_length, batch_size, tgt_dim]
        :param mask: [src_sequence_length, batch_size]
            or [tgt_sequence_length, src_sequence_length, batch_size]
        :param pre_compute: [src_sequence_length, batch_size, head, head_dim]
        :return: [src_sequence_length, batch_size, head]
            or [tgt_sequence_length, src_sequence_length, batch_size, head]
        '''
s_shape = src.get_shape().as_list()
h_shape = tgt.get_shape().as_list()
src_dim = s_shape[-1]
tgt_dim = h_shape[-1]
print('src tgt dim: ', src_dim, tgt_dim)
assert src_dim is not None, 'src dimension must be defined'
assert tgt_dim is not None, 'tgt dimension must be defined'
self._define_params(tgt_dim)
if len(h_shape) == 2:
tgt = tf.expand_dims(tgt, 0)
tgt_project = self._var['tgt_project']
tgt_bias = self._var['tgt_bias']
if pre_compute is None:
pre_compute = self.get_pre_compute(src)
src = pre_compute
tgt = split_last_dim(tf.tensordot(tgt, tgt_project,
[[2], [0]]) + tgt_bias, self._head)
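        # Combine additive (tanh-based) and dot-product scores; either term
        # stays zero when the corresponding flag is disabled.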
add_attention = 0
dot_attention = 0
if self._add:
buf = tf.tanh(tf.expand_dims(src, 0) + tf.expand_dims(tgt, 1))
v = self.var['v']
add_attention = tf.squeeze(batch_linear_layer(buf, v), -1)
if self._dot:
dot_attention = tf.reduce_sum(tf.expand_dims(
src, 0) * tf.expand_dims(tgt, 1), -1)
dot_attention /= self._div
attention = add_attention + dot_attention
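        # Push the logits of masked (padding) positions down by 10000 so the
        # softmax over the source-length axis gives them near-zero probability.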
mask = tf.expand_dims(mask, -1)
logits = attention + (mask - 1) * 10000.0
prob = tf.nn.softmax(logits, 1)
if len(h_shape) == 2:
prob = tf.squeeze(prob, axis=[0])
return prob
def map_target(self, tgt):
tgt_project = self._var['tgt_project']
tgt_bias = self._var['tgt_bias']
tgt = tf.tensordot(tgt, tgt_project, [[1], [0]]) + tgt_bias
return tgt
def get_att(self, src, prob):
'''
        :param src: [src_sequence_length, batch_size, head, head_dim]
        :param prob: [src_sequence_length, batch_size, head]
            or [tgt_sequence_length, src_sequence_length, batch_size, head]
        :return: [batch_size, hidden_dim] or [tgt_sequence_length, batch_size, hidden_dim]
'''
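        # Weighted sum over source positions, then merge the head and
        # head_dim axes back into a single feature dimension.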
buf = merge_last2_dim(tf.reduce_sum(
src * tf.expand_dims(prob, axis=-1), axis=-4))
return buf
class DotAttentionWrapper(RNNCell):
'''
A wrapper for DotAttention or MultiHeadAttention.
'''
def __init__(self, cell, attention,
src, mask, is_gated,
reuse=None, dropout=None,
keep_input=True, map_target=False):
        super().__init__(_reuse=reuse)
assert isinstance(attention, (DotAttention, MultiHeadAttention)), \
'type of attention is not supported'
assert isinstance(cell, RNNCell), 'type of cell must be RNNCell'
self._attention = attention
self._src = src
self._mask = mask
self._pre_computed = None
self._is_gated = is_gated
self._cell = cell
self._dropout = dropout
self._keep_input = keep_input
self._map_target = map_target
@property
def state_size(self):
return self._cell.state_size
@property
def output_size(self):
return self._cell.output_size
def call(self, inputs, state):
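        # Project the source once and cache it; later time steps reuse it.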
if self._pre_computed is None:
self._pre_computed = self._attention.get_pre_compute(self._src)
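        # The attention query is the current input concatenated with the
        # recurrent state.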
att_prob = self._attention.get_prob(
src=self._src,
tgt=tf.concat([inputs, state], axis=1),
mask=self._mask,
pre_compute=self._pre_computed)
if isinstance(self._attention, DotAttention):
att = self._attention.get_att(self._src, att_prob)
else:
att = self._attention.get_att(self._pre_computed, att_prob)
x_list = [att]
if self._keep_input:
x_list.append(inputs)
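            # If the input and attention vectors have the same width, also
            # feed their difference and element-wise product as match features.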
if inputs.shape[1] == att.shape[1]:
x_list.append(inputs - att)
x_list.append(inputs * att)
if self._map_target and isinstance(self._attention, MultiHeadAttention):
tgt = self._attention.map_target(
tf.concat([inputs, state], axis=1))
            x_list += [tgt, att - tgt, att * tgt]
x = tf.concat(x_list, axis=1)
dim = x.get_shape().as_list()[1]
assert dim is not None, 'dim must be defined'
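        # Optional element-wise sigmoid gate over the concatenated features
        # (gated attention) before they are fed to the wrapped cell.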
if self._is_gated:
g = tf.get_variable('att_gate',
shape=[dim, dim],
dtype=tf.float32,
initializer=None)
bias_g = tf.get_variable(
'bias_gate', shape=[1, dim], dtype=tf.float32)
gate = tf.sigmoid(tf.matmul(x, g) + bias_g)
x = x * gate
if self._dropout is not None:
x = self._dropout(x)
return self._cell.call(x, state)
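
For reference, a rough usage sketch of the removed wrapper (not part of this commit): it assumes a TF 1.x graph and illustrative tensors passage [src_len, batch, dim], passage_mask [src_len, batch] and question [tgt_len, batch, dim].

# Hypothetical sketch only; tensor names, sizes and wiring are assumptions.
gru = tf.nn.rnn_cell.GRUCell(num_units=75)
attention = MultiHeadAttention('match', hidden_dim=75, head=3)
cell = DotAttentionWrapper(gru, attention,
                           src=passage,        # encoded passage to attend over
                           mask=passage_mask,  # 1 for real tokens, 0 for padding
                           is_gated=True)
outputs, final_state = tf.nn.dynamic_rnn(
    cell, question, time_major=True, dtype=tf.float32)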