Commit ca2da2ee authored by Arkanath Pathak's avatar Arkanath Pathak
Browse files

Added ptn directory

parent 2a5f2a95
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Training decoder as used in PTN (NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
@tf.contrib.framework.add_arg_scope
def conv3d_transpose(inputs,
num_outputs,
kernel_size,
stride=1,
padding='SAME',
activation_fn=tf.nn.relu,
weights_initializer=tf.contrib.layers.xavier_initializer(),
biases_initializer=tf.zeros_initializer(),
reuse=None,
trainable=True,
scope=None):
"""Wrapper for conv3d_transpose layer.
This function wraps the tf.conv3d_transpose with basic non-linearity.
Tt creates a variable called `weights`, representing the kernel,
that is convoled with the input. A second varibale called `biases'
is added to the result of operation.
"""
with tf.variable_scope(
scope, 'Conv3d_transpose', [inputs], reuse=reuse):
dtype = inputs.dtype.base_dtype
kernel_d, kernel_h, kernel_w = kernel_size[0:3]
num_filters_in = inputs.get_shape()[4]
weights_shape = [kernel_d, kernel_h, kernel_w, num_outputs, num_filters_in]
weights = tf.get_variable('weights',
shape=weights_shape,
dtype=dtype,
initializer=weights_initializer,
trainable=trainable)
tf.contrib.framework.add_model_variable(weights)
input_shape = inputs.get_shape().as_list()
batch_size = input_shape[0]
depth = input_shape[1]
height = input_shape[2]
width = input_shape[3]
def get_deconv_dim(dim_size, stride_size):
# Only support padding='SAME'.
if isinstance(dim_size, tf.Tensor):
dim_size = tf.multiply(dim_size, stride_size)
elif dim_size is not None:
dim_size *= stride_size
return dim_size
out_depth = get_deconv_dim(depth, stride)
out_height = get_deconv_dim(height, stride)
out_width = get_deconv_dim(width, stride)
out_shape = [batch_size, out_depth, out_height, out_width, num_outputs]
outputs = tf.nn.conv3d_transpose(inputs, weights, out_shape,
[1, stride, stride, stride, 1],
padding=padding)
outputs.set_shape(out_shape)
if biases_initializer is not None:
biases = tf.get_variable('biases',
shape=[num_outputs,],
dtype=dtype,
initializer=biases_initializer,
trainable=trainable)
tf.contrib.framework.add_model_variable(biases)
outputs = tf.nn.bias_add(outputs, biases)
if activation_fn:
outputs = activation_fn(outputs)
return outputs
def model(identities, params, is_training):
"""Model transforming embedding to voxels."""
del is_training # Unused
f_dim = params.f_dim
# Please refer to the original implementation: github.com/xcyan/nips16_PTN
# In TF replication, we use a slightly different architecture.
with slim.arg_scope(
[slim.fully_connected, conv3d_transpose],
weights_initializer=tf.truncated_normal_initializer(stddev=0.02, seed=1)):
h0 = slim.fully_connected(
identities, 4 * 4 * 4 * f_dim * 8, activation_fn=tf.nn.relu)
h1 = tf.reshape(h0, [-1, 4, 4, 4, f_dim * 8])
h1 = conv3d_transpose(
h1, f_dim * 4, [4, 4, 4], stride=2, activation_fn=tf.nn.relu)
h2 = conv3d_transpose(
h1, int(f_dim * 3 / 2), [5, 5, 5], stride=2, activation_fn=tf.nn.relu)
h3 = conv3d_transpose(
h2, 1, [6, 6, 6], stride=2, activation_fn=tf.nn.sigmoid)
return h3
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains training plan for the Rotator model (Pretraining in NIPS16)."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow import app
import model_rotator as model
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir', '',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('a_dim', 3, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('step_size', 1, 'Steps to take for rotation in pretraining.')
flags.DEFINE_integer('batch_size', 32, 'Batch size for training.')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_im_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('rotator_name', 'ptn_rotator',
'Name of the rotator network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'deeprotator_pretrain',
'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string('init_model', None,
'Checkpoint path of the model to initialize with.')
flags.DEFINE_integer('save_every', 1000,
'Average period of steps after which we save a model.')
# Optimization
flags.DEFINE_float('image_weight', 10, 'Weighting factor for image loss.')
flags.DEFINE_float('mask_weight', 1, 'Weighting factor for mask loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, 'Weight decay parameter while training.')
flags.DEFINE_float('clip_gradient_norm', 0, 'Gradient clim norm, leave 0 if no gradient clipping.')
flags.DEFINE_integer('max_number_of_steps', 320000, 'Maximum number of steps for training.')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, 'Seconds interval for dumping TF summaries.')
flags.DEFINE_integer('save_interval_secs', 60 * 5, 'Seconds interval to save models.')
# Distribution
flags.DEFINE_string('master', '', 'The address of the tensorflow master if running distributed.')
flags.DEFINE_bool('sync_replicas', False, 'Whether to sync gradients between replicas for optimizer.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas (train tasks).')
flags.DEFINE_integer('backup_workers', 0, 'Number of backup workers.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of ps tasks.')
flags.DEFINE_integer('task', 0,
'Task identifier flag to be set for each task running in distributed manner. Task number 0 '
'will be chosen as the chief.')
FLAGS = flags.FLAGS
def main(_):
train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
save_image_dir = os.path.join(train_dir, 'images')
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(save_image_dir):
os.makedirs(save_image_dir)
g = tf.Graph()
with g.as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
##########
## data ##
##########
train_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'train',
FLAGS.batch_size,
FLAGS.image_size,
is_training=True)
inputs = model.preprocess(train_data, FLAGS.step_size)
###########
## model ##
###########
model_fn = model.get_model_fn(FLAGS, is_training=True)
outputs = model_fn(inputs)
##########
## loss ##
##########
task_loss = model.get_loss(inputs, outputs, FLAGS)
regularization_loss = model.get_regularization_loss(
['encoder', 'rotator', 'decoder'], FLAGS)
loss = task_loss + regularization_loss
###############
## optimizer ##
###############
optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
if FLAGS.sync_replicas:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer,
replicas_to_aggregate=FLAGS.workers_replicas - FLAGS.backup_workers,
total_num_replicas=FLAGS.worker_replicas)
##############
## train_op ##
##############
train_op = model.get_train_op_for_scope(
loss, optimizer, ['encoder', 'rotator', 'decoder'], FLAGS)
###########
## saver ##
###########
saver = tf.train.Saver(max_to_keep=np.minimum(5,
FLAGS.worker_replicas + 1))
if FLAGS.task == 0:
val_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'val',
FLAGS.batch_size,
FLAGS.image_size,
is_training=False)
val_inputs = model.preprocess(val_data, FLAGS.step_size)
# Note: don't compute loss here
reused_model_fn = model.get_model_fn(
FLAGS, is_training=False, reuse=True)
val_outputs = reused_model_fn(val_inputs)
with tf.device(tf.DeviceSpec(device_type='CPU')):
if FLAGS.step_size == 1:
vis_input_images = val_inputs['images_0'] * 255.0
vis_output_images = val_inputs['images_1'] * 255.0
vis_pred_images = val_outputs['images_1'] * 255.0
vis_pred_masks = (val_outputs['masks_1'] * (-1) + 1) * 255.0
else:
rep_times = int(np.ceil(32.0 / float(FLAGS.step_size)))
vis_list_1 = []
vis_list_2 = []
vis_list_3 = []
vis_list_4 = []
for j in xrange(rep_times):
for k in xrange(FLAGS.step_size):
vis_input_image = val_inputs['images_0'][j],
vis_output_image = val_inputs['images_%d' % (k + 1)][j]
vis_pred_image = val_outputs['images_%d' % (k + 1)][j]
vis_pred_mask = val_outputs['masks_%d' % (k + 1)][j]
vis_list_1.append(tf.expand_dims(vis_input_image, 0))
vis_list_2.append(tf.expand_dims(vis_output_image, 0))
vis_list_3.append(tf.expand_dims(vis_pred_image, 0))
vis_list_4.append(tf.expand_dims(vis_pred_mask, 0))
vis_list_1 = tf.reshape(
tf.stack(vis_list_1), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 3
])
vis_list_2 = tf.reshape(
tf.stack(vis_list_2), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 3
])
vis_list_3 = tf.reshape(
tf.stack(vis_list_3), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 3
])
vis_list_4 = tf.reshape(
tf.stack(vis_list_4), [
rep_times * FLAGS.step_size, FLAGS.image_size,
FLAGS.image_size, 1
])
vis_input_images = vis_list_1 * 255.0
vis_output_images = vis_list_2 * 255.0
vis_pred_images = vis_list_3 * 255.0
vis_pred_masks = (vis_list_4 * (-1) + 1) * 255.0
write_disk_op = model.write_disk_grid(
global_step=global_step,
summary_freq=FLAGS.save_every,
log_dir=save_image_dir,
input_images=vis_input_images,
output_images=vis_output_images,
pred_images=vis_pred_images,
pred_masks=vis_pred_masks)
with tf.control_dependencies([write_disk_op]):
train_op = tf.identity(train_op)
#############
## init_fn ##
#############
init_fn = model.get_init_fn(['encoder, ' 'rotator', 'decoder'], FLAGS)
##############
## training ##
##############
slim.learning.train(
train_op=train_op,
logdir=train_dir,
init_fn=init_fn,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
number_of_steps=FLAGS.max_number_of_steps,
saver=saver,
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Contains training plan for the Im2vox model."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os
import numpy as np
import tensorflow as tf
from tensorflow import app
import model_ptn
flags = tf.app.flags
slim = tf.contrib.slim
flags.DEFINE_string('inp_dir',
'',
'Directory path containing the input data (tfrecords).')
flags.DEFINE_string(
'dataset_name', 'shapenet_chair',
'Dataset name that is to be used for training and evaluation.')
flags.DEFINE_integer('z_dim', 512, '')
flags.DEFINE_integer('f_dim', 64, '')
flags.DEFINE_integer('fc_dim', 1024, '')
flags.DEFINE_integer('num_views', 24, 'Num of viewpoints in the input data.')
flags.DEFINE_integer('image_size', 64,
'Input images dimension (pixels) - width & height.')
flags.DEFINE_integer('vox_size', 32, 'Voxel prediction dimension.')
flags.DEFINE_integer('step_size', 24, 'Steps to take in rotation to fetch viewpoints.')
flags.DEFINE_integer('batch_size', 1, 'Batch size while training.')
flags.DEFINE_float('focal_length', 0.866, 'Focal length parameter used in perspective projection.')
flags.DEFINE_float('focal_range', 1.732, 'Focal length parameter used in perspective projection.')
flags.DEFINE_string('encoder_name', 'ptn_encoder',
'Name of the encoder network being used.')
flags.DEFINE_string('decoder_name', 'ptn_vox_decoder',
'Name of the decoder network being used.')
flags.DEFINE_string('projector_name', 'perspective_projector',
'Name of the projector network being used.')
# Save options
flags.DEFINE_string('checkpoint_dir', '/tmp/ptn_train/',
'Directory path for saving trained models and other data.')
flags.DEFINE_string('model_name', 'ptn_finetune',
'Name of the model used in naming the TF job. Must be different for each run.')
flags.DEFINE_string('init_model', None,
'Checkpoint path of the model to initialize with.')
flags.DEFINE_integer('save_every', 1000,
'Average period of steps after which we save a model.')
# Optimization
flags.DEFINE_float('proj_weight', 10, 'Weighting factor for projection loss.')
flags.DEFINE_float('volume_weight', 0, 'Weighting factor for volume loss.')
flags.DEFINE_float('viewpoint_weight', 1, 'Weighting factor for viewpoint loss.')
flags.DEFINE_float('learning_rate', 0.0001, 'Learning rate.')
flags.DEFINE_float('weight_decay', 0.001, 'Weight decay parameter while training.')
flags.DEFINE_float('clip_gradient_norm', 0, 'Gradient clim norm, leave 0 if no gradient clipping.')
flags.DEFINE_integer('max_number_of_steps', 10000, 'Maximum number of steps for training.')
# Summary
flags.DEFINE_integer('save_summaries_secs', 15, 'Seconds interval for dumping TF summaries.')
flags.DEFINE_integer('save_interval_secs', 60 * 5, 'Seconds interval to save models.')
# Scheduling
flags.DEFINE_string('master', '', 'The address of the tensorflow master')
flags.DEFINE_bool('sync_replicas', False, 'Whether to sync gradients between replicas for optimizer.')
flags.DEFINE_integer('worker_replicas', 1, 'Number of worker replicas (train tasks).')
flags.DEFINE_integer('backup_workers', 0, 'Number of backup workers.')
flags.DEFINE_integer('ps_tasks', 0, 'Number of ps tasks.')
flags.DEFINE_integer('task', 0,
'Task identifier flag to be set for each task running in distributed manner. Task number 0 '
'will be chosen as the chief.')
FLAGS = flags.FLAGS
def main(_):
train_dir = os.path.join(FLAGS.checkpoint_dir, FLAGS.model_name, 'train')
save_image_dir = os.path.join(train_dir, 'images')
if not os.path.exists(train_dir):
os.makedirs(train_dir)
if not os.path.exists(save_image_dir):
os.makedirs(save_image_dir)
g = tf.Graph()
with g.as_default():
with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks)):
global_step = slim.get_or_create_global_step()
###########
## model ##
###########
model = model_ptn.model_PTN(FLAGS)
##########
## data ##
##########
train_data = model.get_inputs(
FLAGS.inp_dir,
FLAGS.dataset_name,
'train',
FLAGS.batch_size,
FLAGS.image_size,
FLAGS.vox_size,
is_training=True)
inputs = model.preprocess(train_data, FLAGS.step_size)
##############
## model_fn ##
##############
model_fn = model.get_model_fn(
is_training=True, reuse=False, run_projection=True)
outputs = model_fn(inputs)
##################
## train_scopes ##
##################
if FLAGS.init_model:
train_scopes = ['decoder']
init_scopes = ['encoder']
else:
train_scopes = ['encoder', 'decoder']
##########
## loss ##
##########
task_loss = model.get_loss(inputs, outputs)
regularization_loss = model.get_regularization_loss(train_scopes)
loss = task_loss + regularization_loss
###############
## optimizer ##
###############
optimizer = tf.train.AdamOptimizer(FLAGS.learning_rate)
if FLAGS.sync_replicas:
optimizer = tf.train.SyncReplicasOptimizer(
optimizer,
replicas_to_aggregate=FLAGS.workers_replicas - FLAGS.backup_workers,
total_num_replicas=FLAGS.worker_replicas)
##############
## train_op ##
##############
train_op = model.get_train_op_for_scope(loss, optimizer, train_scopes)
###########
## saver ##
###########
saver = tf.train.Saver(max_to_keep=np.minimum(5,
FLAGS.worker_replicas + 1))
if FLAGS.task == 0:
params = FLAGS
params.batch_size = params.num_views
params.step_size = 1
model.set_params(params)
val_data = model.get_inputs(
params.inp_dir,
params.dataset_name,
'val',
params.batch_size,
params.image_size,
params.vox_size,
is_training=False)
val_inputs = model.preprocess(val_data, params.step_size)
# Note: don't compute loss here
reused_model_fn = model.get_model_fn(is_training=False, reuse=True)
val_outputs = reused_model_fn(val_inputs)
with tf.device(tf.DeviceSpec(device_type='CPU')):
vis_input_images = val_inputs['images_1'] * 255.0
vis_gt_projs = (val_outputs['masks_1'] * (-1) + 1) * 255.0
vis_pred_projs = (val_outputs['projs_1'] * (-1) + 1) * 255.0
vis_gt_projs = tf.concat([vis_gt_projs] * 3, axis=3)
vis_pred_projs = tf.concat([vis_pred_projs] * 3, axis=3)
# rescale
new_size = [FLAGS.image_size] * 2
vis_gt_projs = tf.image.resize_nearest_neighbor(
vis_gt_projs, new_size)
vis_pred_projs = tf.image.resize_nearest_neighbor(
vis_pred_projs, new_size)
# flip
# vis_gt_projs = utils.image_flipud(vis_gt_projs)
# vis_pred_projs = utils.image_flipud(vis_pred_projs)
# vis_gt_projs is of shape [batch, height, width, channels]
write_disk_op = model.write_disk_grid(
global_step=global_step,
log_dir=save_image_dir,
input_images=vis_input_images,
gt_projs=vis_gt_projs,
pred_projs=vis_pred_projs,
input_voxels=val_inputs['voxels'],
output_voxels=val_outputs['voxels_1'])
with tf.control_dependencies([write_disk_op]):
train_op = tf.identity(train_op)
#############
## init_fn ##
#############
if FLAGS.init_model:
init_fn = model.get_init_fn(init_scopes)
else:
init_fn = None
##############
## training ##
##############
slim.learning.train(
train_op=train_op,
logdir=train_dir,
init_fn=init_fn,
master=FLAGS.master,
is_chief=(FLAGS.task == 0),
number_of_steps=FLAGS.max_number_of_steps,
saver=saver,
save_summaries_secs=FLAGS.save_summaries_secs,
save_interval_secs=FLAGS.save_interval_secs)
if __name__ == '__main__':
app.run()
# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utility functions."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import StringIO
from matplotlib import pylab as p
# axes3d is being used implictly for visualization.
from mpl_toolkits.mplot3d import axes3d as p3 # pylint:disable=unused-import
import numpy as np
from PIL import Image
from skimage import measure
import tensorflow as tf
def save_image(inp_array, image_file):
"""Function that dumps the image to disk."""
inp_array = np.clip(inp_array, 0, 255).astype(np.uint8)
image = Image.fromarray(inp_array)
buf = StringIO.StringIO()
image.save(buf, format='JPEG')
with open(image_file, 'w') as f:
f.write(buf.getvalue())
return None
def image_flipud(images):
"""Function that flip (up-down) the np image."""
quantity = images.get_shape().as_list()[0]
image_list = []
for k in xrange(quantity):
image_list.append(tf.image.flip_up_down(images[k, :, :, :]))
outputs = tf.stack(image_list)
return outputs
def resize_image(inp_array, new_height, new_width):
"""Function that resize the np image."""
inp_array = np.clip(inp_array, 0, 255).astype(np.uint8)
image = Image.fromarray(inp_array)
# Reverse order
image = image.resize((new_width, new_height))
return np.array(image)
def display_voxel(points, vis_size=128):
"""Function to display 3D voxel."""
try:
data = visualize_voxel_spectral(points, vis_size)
except ValueError:
data = visualize_voxel_scatter(points, vis_size)
return data
def visualize_voxel_spectral(points, vis_size=128):
"""Function to visualize voxel (spectral)."""
points = np.rint(points)
points = np.swapaxes(points, 0, 2)
fig = p.figure(figsize=(1, 1), dpi=vis_size)
verts, faces = measure.marching_cubes(points, 0, spacing=(0.1, 0.1, 0.1))
ax = fig.add_subplot(111, projection='3d')
ax.plot_trisurf(
verts[:, 0], verts[:, 1], faces, verts[:, 2], cmap='Spectral_r', lw=0.1)
ax.set_axis_off()
fig.tight_layout(pad=0)
fig.canvas.draw()
data = np.fromstring(
fig.canvas.tostring_rgb(), dtype=np.uint8, sep='').reshape(
vis_size, vis_size, 3)
p.close('all')
return data
def visualize_voxel_scatter(points, vis_size=128):
"""Function to visualize voxel (scatter)."""
points = np.rint(points)
points = np.swapaxes(points, 0, 2)
fig = p.figure(figsize=(1, 1), dpi=vis_size)
ax = fig.add_subplot(111, projection='3d')
x = []
y = []
z = []
(x_dimension, y_dimension, z_dimension) = points.shape
for i in range(x_dimension):
for j in range(y_dimension):
for k in range(z_dimension):
if points[i, j, k]:
x.append(i)
y.append(j)
z.append(k)
ax.scatter3D(x, y, z)
ax.set_axis_off()
fig.tight_layout(pad=0)
fig.canvas.draw()
data = np.fromstring(
fig.canvas.tostring_rgb(), dtype=np.uint8, sep='').reshape(
vis_size, vis_size, 3)
p.close('all')
return data
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment