Commit ed54e8dd authored by Andrew M. Dai

Change logit and activation shapes to follow standard convention with batch size in the first dimension.

PiperOrigin-RevId: 172172495
parent 1024f926
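
Editorial note (not part of the commit): the sketch below illustrates the convention change with assumed toy shapes. Under the old convention, time-major activations of shape (time, batch, dim) were flattened to (time*batch, dim); under the new convention the batch dimension comes first and the time axis is kept intact.

    # Minimal sketch of the convention change; sizes are assumed toy values.
    import tensorflow as tf

    time_steps, batch_size, dim = 5, 3, 4
    time_major = tf.zeros([time_steps, batch_size, dim])  # (time, batch, dim)

    # Old convention: merge time and batch into one leading dimension.
    merged = tf.reshape(time_major, [-1, dim])            # (time*batch, dim)

    # New convention: batch size in the first dimension, time axis preserved.
    batch_major = tf.transpose(time_major, [1, 0, 2])     # (batch, time, dim)
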
@@ -51,27 +51,16 @@ class VatxtInput(object):
         batch.sequences[data_utils.SequenceWrapper.F_TOKEN_ID])
     self._num_states = num_states
 
-    # Once the tokens have passed through embedding and LSTM, the output Tensor
-    # shapes will be time-major, i.e. shape = (time, batch, dim). Here we make
-    # both weights and labels time-major with a transpose, and then merge the
-    # time and batch dimensions such that they are both vectors of shape
-    # (time*batch).
     w = batch.sequences[data_utils.SequenceWrapper.F_WEIGHT]
-    w = tf.transpose(w, [1, 0])
-    w = tf.reshape(w, [-1])
     self._weights = w
 
     l = batch.sequences[data_utils.SequenceWrapper.F_LABEL]
-    l = tf.transpose(l, [1, 0])
-    l = tf.reshape(l, [-1])
     self._labels = l
 
     # eos weights
     self._eos_weights = None
     if eos_id:
       ew = tf.cast(tf.equal(self._tokens, eos_id), tf.float32)
-      ew = tf.transpose(ew, [1, 0])
-      ew = tf.reshape(ew, [-1])
       self._eos_weights = ew
 
   @property
...
@@ -125,10 +125,9 @@ class LSTM(object):
     lstm_out, next_state = tf.contrib.rnn.static_rnn(
         cell, x, initial_state=initial_state, sequence_length=seq_length)
 
-    # Merge time and batch dimensions
-    # shape(lstm_out) = timesteps * (batch_size, cell_size)
-    lstm_out = tf.concat(lstm_out, 0)
-    # shape(lstm_out) = (timesteps*batch_size, cell_size)
+    lstm_out = tf.stack(lstm_out)
+    # shape(lstm_out) = (timesteps, batch_size, cell_size)
+    lstm_out = tf.transpose(lstm_out, [1, 0, 2])
 
     if self.keep_prob < 1.:
       lstm_out = tf.nn.dropout(lstm_out, self.keep_prob)
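
Editorial note (not part of the commit): tf.contrib.rnn.static_rnn returns a Python list with one (batch_size, cell_size) tensor per timestep. The hunk above stacks that list into a single time-major tensor and transposes it to batch-major, rather than concatenating timesteps into the batch axis as before. A minimal sketch with assumed toy shapes:

    # Sketch only: a stand-in for static_rnn's per-timestep output list.
    import tensorflow as tf

    timesteps, batch_size, cell_size = 4, 2, 8  # assumed toy sizes
    lstm_out = [tf.zeros([batch_size, cell_size]) for _ in range(timesteps)]

    lstm_out = tf.stack(lstm_out)                 # (timesteps, batch_size, cell_size)
    lstm_out = tf.transpose(lstm_out, [1, 0, 2])  # (batch_size, timesteps, cell_size)
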
@@ -172,23 +171,28 @@ class SoftmaxLoss(K.layers.Layer):
     x, labels, weights = inputs
     if self.num_candidate_samples > -1:
       assert self.vocab_freqs is not None
-      labels = tf.expand_dims(labels, -1)
+      labels_reshaped = tf.reshape(labels, [-1])
+      labels_reshaped = tf.expand_dims(labels_reshaped, -1)
       sampled = tf.nn.fixed_unigram_candidate_sampler(
-          true_classes=labels,
+          true_classes=labels_reshaped,
           num_true=1,
           num_sampled=self.num_candidate_samples,
           unique=True,
           range_max=self.vocab_size,
           unigrams=self.vocab_freqs)
+      inputs_reshaped = tf.reshape(x, [-1, int(x.get_shape()[2])])
       lm_loss = tf.nn.sampled_softmax_loss(
           weights=tf.transpose(self.lin_w),
           biases=self.lin_b,
-          labels=labels,
-          inputs=x,
+          labels=labels_reshaped,
+          inputs=inputs_reshaped,
           num_sampled=self.num_candidate_samples,
           num_classes=self.vocab_size,
           sampled_values=sampled)
+      lm_loss = tf.reshape(
+          lm_loss,
+          [int(x.get_shape()[0]), int(x.get_shape()[1])])
     else:
       logits = tf.matmul(x, self.lin_w) + self.lin_b
       lm_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
...
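
Editorial note (not part of the commit): tf.nn.sampled_softmax_loss expects 2-D inputs of shape (batch, dim) and int64 labels of shape (batch, num_true), so with batch-major 3-D activations the hunk above flattens both, computes the loss, and reshapes the result back to (batch, time). A minimal standalone sketch, with assumed toy sizes and zero tensors standing in for the layer's real weights:

    # Sketch only; sizes and weights are assumed, not the layer's real values.
    import tensorflow as tf

    batch_size, timesteps, dim, vocab = 2, 3, 8, 100
    x = tf.zeros([batch_size, timesteps, dim])                 # batch-major activations
    labels = tf.zeros([batch_size, timesteps], dtype=tf.int64)
    w = tf.zeros([vocab, dim])  # sampled_softmax_loss wants (num_classes, dim)
    b = tf.zeros([vocab])

    # Flatten to the 2-D shapes the sampled loss expects.
    labels_reshaped = tf.expand_dims(tf.reshape(labels, [-1]), -1)  # (batch*time, 1)
    inputs_reshaped = tf.reshape(x, [-1, dim])                      # (batch*time, dim)

    lm_loss = tf.nn.sampled_softmax_loss(
        weights=w,
        biases=b,
        labels=labels_reshaped,
        inputs=inputs_reshaped,
        num_sampled=16,
        num_classes=vocab)

    # Restore the batch-major (batch, time) per-token loss shape.
    lm_loss = tf.reshape(lm_loss, [batch_size, timesteps])
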