".github/git@developer.sourcefind.cn:wangsen/mineru.git" did not exist on "7b61571cc58a72afa8b430e7cd7d2ac468bd073d"
Commit 5c535343 authored by Neal Wu

Manually fixed many occurrences of tf.split

parent fdc0c4ab
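
For context: TensorFlow 1.0 reordered the arguments of tf.split from tf.split(split_dim, num_split, value) to tf.split(value, num_or_size_splits, axis). In the calls below, the split axis and the tensor had been bound to the wrong keywords, so each occurrence is swapped back by hand. A minimal sketch of the pattern (assuming TF 1.x; the `logits` tensor and split count are illustrative, not part of this commit):

import tensorflow as tf

# Hypothetical [batch, num_classes] input; batch divisible by the split count.
logits = tf.placeholder(tf.float32, shape=[8, 10], name="logits")

# Broken binding: the tensor lands in `axis` and the integer 0 in `value`,
# which fails at graph-construction time.
# outputs = tf.split(axis=tf.nn.log_softmax(logits), num_or_size_splits=4, value=0)

# Corrected binding per the TF 1.0 signature tf.split(value, num_or_size_splits, axis):
# split the batch dimension into 4 tensors of shape [2, 10].
outputs = tf.split(axis=0, num_or_size_splits=4, value=tf.nn.log_softmax(logits))
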
@@ -211,7 +211,7 @@ def reorder_beam(beam_size, batch_size, beam_val, output, is_first,
   # beam_val is [batch_size x beam_size]; let b = batch_size * beam_size
   # decided is len x b x a x b
   # output is b x out_size; step is b x len x a x b;
-  outputs = tf.split(axis=tf.nn.log_softmax(output), num_or_size_splits=beam_size, value=0)
+  outputs = tf.split(axis=0, num_or_size_splits=beam_size, value=tf.nn.log_softmax(output))
   all_beam_vals, all_beam_idx = [], []
   beam_range = 1 if is_first else beam_size
   for i in xrange(beam_range):
@@ -266,9 +266,9 @@ class NeuralGPU(object):
     self.input = tf.placeholder(tf.int32, name="inp")
     self.target = tf.placeholder(tf.int32, name="tgt")
     self.prev_step = tf.placeholder(tf.float32, name="prev_step")
-    gpu_input = tf.split(axis=self.input, num_or_size_splits=num_gpus, value=0)
-    gpu_target = tf.split(axis=self.target, num_or_size_splits=num_gpus, value=0)
-    gpu_prev_step = tf.split(axis=self.prev_step, num_or_size_splits=num_gpus, value=0)
+    gpu_input = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.input)
+    gpu_target = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.target)
+    gpu_prev_step = tf.split(axis=0, num_or_size_splits=num_gpus, value=self.prev_step)
     batch_size = tf.shape(gpu_input[0])[0]
     if backward:
...
@@ -332,7 +332,7 @@ def masked_conv_aff_coupling(input_, mask_in, dim, name,
         residual_blocks=residual_blocks,
         bottleneck=bottleneck, skip=skip)
     mask = tf.mod(mask_channel + mask, 2)
-    res = tf.split(axis=res, num_or_size_splits=2, value=3)
+    res = tf.split(axis=3, num_or_size_splits=2, value=res)
     shift, log_rescaling = res[-2], res[-1]
     scale = variable_on_cpu(
         "rescaling_scale", [],
@@ -486,9 +486,9 @@ def conv_ch_aff_coupling(input_, dim, name,
       scope.reuse_variables()
     if change_bottom:
-      input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
+      input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
     else:
-      canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
+      canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
     shape = input_.get_shape().as_list()
     batch_size = shape[0]
     height = shape[1]
@@ -509,7 +509,7 @@ def conv_ch_aff_coupling(input_, dim, name,
         train=train, weight_norm=weight_norm,
         residual_blocks=residual_blocks,
         bottleneck=bottleneck, skip=skip)
-    shift, log_rescaling = tf.split(axis=res, num_or_size_splits=2, value=3)
+    shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)
     scale = variable_on_cpu(
         "scale", [],
         tf.constant_initializer(1.))
@@ -570,9 +570,9 @@ def conv_ch_add_coupling(input_, dim, name,
      scope.reuse_variables()
    if change_bottom:
-      input_, canvas = tf.split(axis=input_, num_or_size_splits=2, value=3)
+      input_, canvas = tf.split(axis=3, num_or_size_splits=2, value=input_)
    else:
-      canvas, input_ = tf.split(axis=input_, num_or_size_splits=2, value=3)
+      canvas, input_ = tf.split(axis=3, num_or_size_splits=2, value=input_)
    shape = input_.get_shape().as_list()
    channels = shape[3]
    res = input_
@@ -736,8 +736,8 @@ def rec_masked_conv_coupling(input_, hps, scale_idx, n_scale,
       log_diff_1 = log_diff[:, :, :, :channels]
       log_diff_2 = log_diff[:, :, :, channels:]
     else:
-      res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
-      log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
+      res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
+      log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
     res_1, inc_log_diff = rec_masked_conv_coupling(
         input_=res_1, hps=hps, scale_idx=scale_idx + 1, n_scale=n_scale,
         use_batch_norm=use_batch_norm, weight_norm=weight_norm,
@@ -798,8 +798,8 @@ def rec_masked_deconv_coupling(input_, hps, scale_idx, n_scale,
       log_diff_1 = log_diff[:, :, :, :channels]
       log_diff_2 = log_diff[:, :, :, channels:]
     else:
-      res_1, res_2 = tf.split(axis=res, num_or_size_splits=2, value=3)
-      log_diff_1, log_diff_2 = tf.split(axis=log_diff, num_or_size_splits=2, value=3)
+      res_1, res_2 = tf.split(axis=3, num_or_size_splits=2, value=res)
+      log_diff_1, log_diff_2 = tf.split(axis=3, num_or_size_splits=2, value=log_diff)
     res_1, log_diff_1 = rec_masked_deconv_coupling(
         input_=res_1, hps=hps,
         scale_idx=scale_idx + 1, n_scale=n_scale,
@@ -1305,7 +1305,7 @@ class RealNVP(object):
     z_lost = z_complete
     for scale_idx in xrange(hps.n_scale - 1):
       z_lost = squeeze_2x2_ordered(z_lost)
-      z_lost, _ = tf.split(axis=z_lost, num_or_size_splits=2, value=3)
+      z_lost, _ = tf.split(axis=3, num_or_size_splits=2, value=z_lost)
     z_compressed = z_lost
     z_noisy = z_lost
     for _ in xrange(scale_idx + 1):
...
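
As a quick sanity check of the corrected pattern used in the coupling layers above (a sketch assuming TF 1.x; the NHWC shape is illustrative, not from the commit), splitting along the channel axis yields two half-channel tensors:

import tensorflow as tf

# Hypothetical NHWC activation: batch 4, 32x32 spatial, 16 channels.
res = tf.zeros([4, 32, 32, 16])

# Split along axis=3 (channels) into two [4, 32, 32, 8] tensors.
shift, log_rescaling = tf.split(axis=3, num_or_size_splits=2, value=res)

print(shift.get_shape().as_list())          # [4, 32, 32, 8]
print(log_rescaling.get_shape().as_list())  # [4, 32, 32, 8]
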