Commit 5d5fb7cc authored by Derek Chow's avatar Derek Chow
Browse files

Enable inference with dynamic batch size in Faster RCNN.

* Adds a util function to compute a mix of dynamic and static shapes
  preferring static when available.
* Uses batch_multiclass_non_max_suppression function in postprocess_rpn
  instead of looping over static batch shape and performing
  multiclass_non_max_suppression.
* Adds a new helper function _unpad_proposals_and_sample_boxclassifier_batch
  to sample from a batch of tensors possibly containing paddings.
* Tests batch inference with various configurations of static shape via
  unittests.
parent 4d641f7f
......@@ -270,6 +270,7 @@ py_library(
deps = [
"//tensorflow",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:shape_utils",
"//tensorflow_models/object_detection/utils:static_shape",
],
)
......
......@@ -29,6 +29,7 @@ few box predictor architectures are shared across many models.
from abc import abstractmethod
import tensorflow as tf
from object_detection.utils import ops
from object_detection.utils import shape_utils
from object_detection.utils import static_shape
slim = tf.contrib.slim
......@@ -524,23 +525,21 @@ class ConvolutionalBoxPredictor(BoxPredictor):
class_predictions_with_background = tf.sigmoid(
class_predictions_with_background)
batch_size = static_shape.get_batch_size(image_features.get_shape())
if batch_size is None:
features_height = static_shape.get_height(image_features.get_shape())
features_width = static_shape.get_width(image_features.get_shape())
flattened_predictions_size = (features_height * features_width *
num_predictions_per_location)
combined_feature_map_shape = shape_utils.combined_static_and_dynamic_shape(
image_features)
box_encodings = tf.reshape(
box_encodings,
[-1, flattened_predictions_size, 1, self._box_code_size])
box_encodings, tf.stack([combined_feature_map_shape[0],
combined_feature_map_shape[1] *
combined_feature_map_shape[2] *
num_predictions_per_location,
1, self._box_code_size]))
class_predictions_with_background = tf.reshape(
class_predictions_with_background,
[-1, flattened_predictions_size, num_class_slots])
else:
box_encodings = tf.reshape(
box_encodings, [batch_size, -1, 1, self._box_code_size])
class_predictions_with_background = tf.reshape(
class_predictions_with_background, [batch_size, -1, num_class_slots])
tf.stack([combined_feature_map_shape[0],
combined_feature_map_shape[1] *
combined_feature_map_shape[2] *
num_predictions_per_location,
num_class_slots]))
return {BOX_ENCODINGS: box_encodings,
CLASS_PREDICTIONS_WITH_BACKGROUND:
class_predictions_with_background}
......@@ -56,6 +56,7 @@ py_library(
"//tensorflow_models/object_detection/core:standard_fields",
"//tensorflow_models/object_detection/core:target_assigner",
"//tensorflow_models/object_detection/utils:ops",
"//tensorflow_models/object_detection/utils:shape_utils",
],
)
......
......@@ -80,6 +80,7 @@ from object_detection.core import post_processing
from object_detection.core import standard_fields as fields
from object_detection.core import target_assigner
from object_detection.utils import ops
from object_detection.utils import shape_utils
slim = tf.contrib.slim
......@@ -765,10 +766,9 @@ class FasterRCNNMetaArch(model.DetectionModel):
A float tensor with shape [A * B, ..., depth] (where the first and last
dimension are statically defined.
"""
inputs_shape = inputs.get_shape().as_list()
flattened_shape = tf.concat([
[inputs_shape[0]*inputs_shape[1]], tf.shape(inputs)[2:-1],
[inputs_shape[-1]]], 0)
combined_shape = shape_utils.combined_static_and_dynamic_shape(inputs)
flattened_shape = tf.stack([combined_shape[0] * combined_shape[1]] +
combined_shape[2:])
return tf.reshape(inputs, flattened_shape)
def postprocess(self, prediction_dict):
......@@ -866,52 +866,128 @@ class FasterRCNNMetaArch(model.DetectionModel):
representing the number of proposals predicted for each image in
the batch.
"""
rpn_box_encodings_batch = tf.expand_dims(rpn_box_encodings_batch, axis=2)
rpn_encodings_shape = shape_utils.combined_static_and_dynamic_shape(
rpn_box_encodings_batch)
tiled_anchor_boxes = tf.tile(
tf.expand_dims(anchors, 0), [rpn_encodings_shape[0], 1, 1])
proposal_boxes = self._batch_decode_boxes(rpn_box_encodings_batch,
tiled_anchor_boxes)
proposal_boxes = tf.squeeze(proposal_boxes, axis=2)
rpn_objectness_softmax_without_background = tf.nn.softmax(
rpn_objectness_predictions_with_background_batch)[:, :, 1]
clip_window = tf.to_float(tf.stack([0, 0, image_shape[1], image_shape[2]]))
if self._is_training:
(groundtruth_boxlists, groundtruth_classes_with_background_list
) = self._format_groundtruth_data(image_shape)
proposal_boxes_list = []
proposal_scores_list = []
num_proposals_list = []
for (batch_index,
(rpn_box_encodings,
rpn_objectness_predictions_with_background)) in enumerate(zip(
tf.unstack(rpn_box_encodings_batch),
tf.unstack(rpn_objectness_predictions_with_background_batch))):
decoded_boxes = self._box_coder.decode(
rpn_box_encodings, box_list.BoxList(anchors))
objectness_scores = tf.unstack(
tf.nn.softmax(rpn_objectness_predictions_with_background), axis=1)[1]
proposal_boxlist = post_processing.multiclass_non_max_suppression(
tf.expand_dims(decoded_boxes.get(), 1),
tf.expand_dims(objectness_scores, 1),
(proposal_boxes, proposal_scores, _, _,
num_proposals) = post_processing.batch_multiclass_non_max_suppression(
tf.expand_dims(proposal_boxes, axis=2),
tf.expand_dims(rpn_objectness_softmax_without_background,
axis=2),
self._first_stage_nms_score_threshold,
self._first_stage_nms_iou_threshold, self._first_stage_max_proposals,
self._first_stage_nms_iou_threshold,
self._first_stage_max_proposals,
self._first_stage_max_proposals,
clip_window=clip_window)
if self._is_training:
proposal_boxlist.set(tf.stop_gradient(proposal_boxlist.get()))
proposal_boxes = tf.stop_gradient(proposal_boxes)
if not self._hard_example_miner:
proposal_boxlist = self._sample_box_classifier_minibatch(
proposal_boxlist, groundtruth_boxlists[batch_index],
groundtruth_classes_with_background_list[batch_index])
normalized_proposals = box_list_ops.to_normalized_coordinates(
proposal_boxlist, image_shape[1], image_shape[2],
check_range=False)
# pad proposals to max_num_proposals
padded_proposals = box_list_ops.pad_or_clip_box_list(
normalized_proposals, num_boxes=self.max_num_proposals)
proposal_boxes_list.append(padded_proposals.get())
proposal_scores_list.append(
padded_proposals.get_field(fields.BoxListFields.scores))
num_proposals_list.append(tf.minimum(normalized_proposals.num_boxes(),
self.max_num_proposals))
return (tf.stack(proposal_boxes_list), tf.stack(proposal_scores_list),
tf.stack(num_proposals_list))
(groundtruth_boxlists, groundtruth_classes_with_background_list,
) = self._format_groundtruth_data(image_shape)
(proposal_boxes, proposal_scores,
num_proposals) = self._unpad_proposals_and_sample_box_classifier_batch(
proposal_boxes, proposal_scores, num_proposals,
groundtruth_boxlists, groundtruth_classes_with_background_list)
# normalize proposal boxes
proposal_boxes_reshaped = tf.reshape(proposal_boxes, [-1, 4])
normalized_proposal_boxes_reshaped = box_list_ops.to_normalized_coordinates(
box_list.BoxList(proposal_boxes_reshaped),
image_shape[1], image_shape[2], check_range=False).get()
proposal_boxes = tf.reshape(normalized_proposal_boxes_reshaped,
[-1, proposal_boxes.shape[1].value, 4])
return proposal_boxes, proposal_scores, num_proposals
def _unpad_proposals_and_sample_box_classifier_batch(
self,
proposal_boxes,
proposal_scores,
num_proposals,
groundtruth_boxlists,
groundtruth_classes_with_background_list):
"""Unpads proposals and samples a minibatch for second stage.
Args:
proposal_boxes: A float tensor with shape
[batch_size, num_proposals, 4] representing the (potentially zero
padded) proposal boxes for all images in the batch. These boxes are
represented as normalized coordinates.
proposal_scores: A float tensor with shape
[batch_size, num_proposals] representing the (potentially zero
padded) proposal objectness scores for all images in the batch.
num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
representing the number of proposals predicted for each image in
the batch.
groundtruth_boxlists: A list of BoxLists containing (absolute) coordinates
of the groundtruth boxes.
groundtruth_classes_with_background_list: A list of 2-D one-hot
(or k-hot) tensors of shape [num_boxes, num_classes+1] containing the
class targets with the 0th index assumed to map to the background class.
Returns:
proposal_boxes: A float tensor with shape
[batch_size, second_stage_batch_size, 4] representing the (potentially
zero padded) proposal boxes for all images in the batch. These boxes
are represented as normalized coordinates.
proposal_scores: A float tensor with shape
[batch_size, second_stage_batch_size] representing the (potentially zero
padded) proposal objectness scores for all images in the batch.
num_proposals: A Tensor of type `int32`. A 1-D tensor of shape [batch]
representing the number of proposals predicted for each image in
the batch.
"""
single_image_proposal_box_sample = []
single_image_proposal_score_sample = []
single_image_num_proposals_sample = []
for (single_image_proposal_boxes,
single_image_proposal_scores,
single_image_num_proposals,
single_image_groundtruth_boxlist,
single_image_groundtruth_classes_with_background) in zip(
tf.unstack(proposal_boxes),
tf.unstack(proposal_scores),
tf.unstack(num_proposals),
groundtruth_boxlists,
groundtruth_classes_with_background_list):
static_shape = single_image_proposal_boxes.get_shape()
sliced_static_shape = tf.TensorShape([tf.Dimension(None),
static_shape.dims[-1]])
single_image_proposal_boxes = tf.slice(
single_image_proposal_boxes,
[0, 0],
[single_image_num_proposals, -1])
single_image_proposal_boxes.set_shape(sliced_static_shape)
single_image_proposal_scores = tf.slice(single_image_proposal_scores,
[0],
[single_image_num_proposals])
single_image_boxlist = box_list.BoxList(single_image_proposal_boxes)
single_image_boxlist.add_field(fields.BoxListFields.scores,
single_image_proposal_scores)
sampled_boxlist = self._sample_box_classifier_minibatch(
single_image_boxlist,
single_image_groundtruth_boxlist,
single_image_groundtruth_classes_with_background)
sampled_padded_boxlist = box_list_ops.pad_or_clip_box_list(
sampled_boxlist,
num_boxes=self._second_stage_batch_size)
single_image_num_proposals_sample.append(tf.minimum(
sampled_boxlist.num_boxes(),
self._second_stage_batch_size))
bb = sampled_padded_boxlist.get()
single_image_proposal_box_sample.append(bb)
single_image_proposal_score_sample.append(
sampled_padded_boxlist.get_field(fields.BoxListFields.scores))
return (tf.stack(single_image_proposal_box_sample),
tf.stack(single_image_proposal_score_sample),
tf.stack(single_image_num_proposals_sample))
def _format_groundtruth_data(self, image_shape):
"""Helper function for preparing groundtruth data for target assignment.
......@@ -1065,7 +1141,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
class_predictions_with_background,
[-1, self.max_num_proposals, self.num_classes + 1]
)
refined_decoded_boxes_batch = self._batch_decode_refined_boxes(
refined_decoded_boxes_batch = self._batch_decode_boxes(
refined_box_encodings_batch, proposal_boxes)
class_predictions_with_background_batch = (
self._second_stage_score_conversion_fn(
......@@ -1102,7 +1178,7 @@ class FasterRCNNMetaArch(model.DetectionModel):
tf.greater_equal(detections['detection_masks'], mask_threshold))
return detections
def _batch_decode_refined_boxes(self, refined_box_encodings, proposal_boxes):
def _batch_decode_boxes(self, box_encodings, anchor_boxes):
"""Decode tensor of refined box encodings.
Args:
......@@ -1117,15 +1193,33 @@ class FasterRCNNMetaArch(model.DetectionModel):
float tensor representing (padded) refined bounding box predictions
(for each image in batch, proposal and class).
"""
tiled_proposal_boxes = tf.tile(
tf.expand_dims(proposal_boxes, 2), [1, 1, self.num_classes, 1])
tiled_proposals_boxlist = box_list.BoxList(
tf.reshape(tiled_proposal_boxes, [-1, 4]))
"""Decodes box encodings with respect to the anchor boxes.
Args:
box_encodings: a 4-D tensor with shape
[batch_size, num_anchors, num_classes, self._box_coder.code_size]
representing box encodings.
anchor_boxes: [batch_size, num_anchors, 4] representing
decoded bounding boxes.
Returns:
decoded_boxes: a [batch_size, num_anchors, num_classes, 4]
float tensor representing bounding box predictions
(for each image in batch, proposal and class).
"""
combined_shape = shape_utils.combined_static_and_dynamic_shape(
box_encodings)
num_classes = combined_shape[2]
tiled_anchor_boxes = tf.tile(
tf.expand_dims(anchor_boxes, 2), [1, 1, num_classes, 1])
tiled_anchors_boxlist = box_list.BoxList(
tf.reshape(tiled_anchor_boxes, [-1, 4]))
decoded_boxes = self._box_coder.decode(
tf.reshape(refined_box_encodings, [-1, self._box_coder.code_size]),
tiled_proposals_boxlist)
tf.reshape(box_encodings, [-1, self._box_coder.code_size]),
tiled_anchors_boxlist)
return tf.reshape(decoded_boxes.get(),
[-1, self.max_num_proposals, self.num_classes, 4])
tf.stack([combined_shape[0], combined_shape[1],
num_classes, 4]))
def loss(self, prediction_dict, scope=None):
"""Compute scalar loss tensors given prediction tensors.
......@@ -1439,4 +1533,3 @@ class FasterRCNNMetaArch(model.DetectionModel):
include_patterns=[self.first_stage_feature_extractor_scope,
self.second_stage_feature_extractor_scope])
return {var.op.name: var for var in feature_extractor_variables}
......@@ -226,61 +226,47 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
return self._get_model(self._get_second_stage_box_predictor(
num_classes=num_classes, is_training=is_training), **common_kwargs)
def test_predict_gives_correct_shapes_in_inference_mode_first_stage_only(
def test_predict_correct_shapes_in_inference_mode_both_stages(
self):
test_graph = tf.Graph()
with test_graph.as_default():
model = self._build_model(
is_training=False, first_stage_only=True, second_stage_batch_size=2)
batch_size = 2
height = 10
width = 12
input_image_shape = (batch_size, height, width, 3)
preprocessed_inputs = tf.placeholder(dtype=tf.float32,
shape=(batch_size, None, None, 3))
prediction_dict = model.predict(preprocessed_inputs)
# In inference mode, anchors are clipped to the image window, but not
# pruned. Since MockFasterRCNN.extract_proposal_features returns a
# tensor with the same shape as its input, the expected number of anchors
# is height * width * the number of anchors per location (i.e. 3x3).
expected_num_anchors = height * width * 3 * 3
expected_output_keys = set([
'rpn_box_predictor_features', 'rpn_features_to_crop', 'image_shape',
'rpn_box_encodings', 'rpn_objectness_predictions_with_background',
'anchors'])
expected_output_shapes = {
'rpn_box_predictor_features': (batch_size, height, width, 512),
'rpn_features_to_crop': (batch_size, height, width, 3),
'rpn_box_encodings': (batch_size, expected_num_anchors, 4),
image_size = 10
input_shapes = [(batch_size, image_size, image_size, 3),
(None, image_size, image_size, 3),
(batch_size, None, None, 3),
(None, None, None, 3)]
expected_num_anchors = image_size * image_size * 3 * 3
expected_shapes = {
'rpn_box_predictor_features':
(2, image_size, image_size, 512),
'rpn_features_to_crop': (2, image_size, image_size, 3),
'image_shape': (4,),
'rpn_box_encodings': (2, expected_num_anchors, 4),
'rpn_objectness_predictions_with_background':
(batch_size, expected_num_anchors, 2),
'anchors': (expected_num_anchors, 4)
(2, expected_num_anchors, 2),
'anchors': (expected_num_anchors, 4),
'refined_box_encodings': (2 * 8, 2, 4),
'class_predictions_with_background': (2 * 8, 2 + 1),
'num_proposals': (2,),
'proposal_boxes': (2, 8, 4),
}
for input_shape in input_shapes:
test_graph = tf.Graph()
with test_graph.as_default():
model = self._build_model(
is_training=False, first_stage_only=False,
second_stage_batch_size=2)
preprocessed_inputs = tf.placeholder(tf.float32, shape=input_shape)
result_tensor_dict = model.predict(preprocessed_inputs)
init_op = tf.global_variables_initializer()
with self.test_session() as sess:
with self.test_session(graph=test_graph) as sess:
sess.run(init_op)
prediction_out = sess.run(prediction_dict,
feed_dict={
tensor_dict_out = sess.run(result_tensor_dict, feed_dict={
preprocessed_inputs:
np.zeros(input_image_shape)
})
self.assertEqual(set(prediction_out.keys()), expected_output_keys)
self.assertAllEqual(prediction_out['image_shape'], input_image_shape)
for output_key, expected_shape in expected_output_shapes.items():
self.assertAllEqual(prediction_out[output_key].shape, expected_shape)
# Check that anchors are clipped to window.
anchors = prediction_out['anchors']
self.assertTrue(np.all(np.greater_equal(anchors, 0)))
self.assertTrue(np.all(np.less_equal(anchors[:, 0], height)))
self.assertTrue(np.all(np.less_equal(anchors[:, 1], width)))
self.assertTrue(np.all(np.less_equal(anchors[:, 2], height)))
self.assertTrue(np.all(np.less_equal(anchors[:, 3], width)))
np.zeros((batch_size, image_size, image_size, 3))})
self.assertEqual(set(tensor_dict_out.keys()),
set(expected_shapes.keys()))
for key in expected_shapes:
self.assertAllEqual(tensor_dict_out[key].shape, expected_shapes[key])
def test_predict_gives_valid_anchors_in_training_mode_first_stage_only(self):
test_graph = tf.Graph()
......@@ -535,35 +521,67 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
expected_num_proposals)
def test_postprocess_second_stage_only_inference_mode(self):
model = self._build_model(
is_training=False, first_stage_only=False, second_stage_batch_size=6)
num_proposals_shapes = [(2), (None)]
refined_box_encodings_shapes = [(16, 2, 4), (None, 2, 4)]
class_predictions_with_background_shapes = [(16, 3), (None, 3)]
proposal_boxes_shapes = [(2, 8, 4), (None, 8, 4)]
batch_size = 2
image_shape = np.array((2, 36, 48, 3), dtype=np.int32)
for (num_proposals_shape, refined_box_encoding_shape,
class_predictions_with_background_shape,
proposal_boxes_shape) in zip(num_proposals_shapes,
refined_box_encodings_shapes,
class_predictions_with_background_shapes,
proposal_boxes_shapes):
tf_graph = tf.Graph()
with tf_graph.as_default():
model = self._build_model(
is_training=False, first_stage_only=False,
second_stage_batch_size=6)
total_num_padded_proposals = batch_size * model.max_num_proposals
proposal_boxes = tf.constant(
proposal_boxes = np.array(
[[[1, 1, 2, 3],
[0, 0, 1, 1],
[.5, .5, .6, .6],
4*[0], 4*[0], 4*[0], 4*[0], 4*[0]],
[[2, 3, 6, 8],
[1, 2, 5, 3],
4*[0], 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]]], dtype=tf.float32)
num_proposals = tf.constant([3, 2], dtype=tf.int32)
refined_box_encodings = tf.zeros(
[total_num_padded_proposals, model.num_classes, 4], dtype=tf.float32)
class_predictions_with_background = tf.ones(
[total_num_padded_proposals, model.num_classes+1], dtype=tf.float32)
image_shape = tf.constant([batch_size, 36, 48, 3], dtype=tf.int32)
4*[0], 4*[0], 4*[0], 4*[0], 4*[0], 4*[0]]])
num_proposals = np.array([3, 2], dtype=np.int32)
refined_box_encodings = np.zeros(
[total_num_padded_proposals, model.num_classes, 4])
class_predictions_with_background = np.ones(
[total_num_padded_proposals, model.num_classes+1])
num_proposals_placeholder = tf.placeholder(tf.int32,
shape=num_proposals_shape)
refined_box_encodings_placeholder = tf.placeholder(
tf.float32, shape=refined_box_encoding_shape)
class_predictions_with_background_placeholder = tf.placeholder(
tf.float32, shape=class_predictions_with_background_shape)
proposal_boxes_placeholder = tf.placeholder(
tf.float32, shape=proposal_boxes_shape)
image_shape_placeholder = tf.placeholder(tf.int32, shape=(4))
detections = model.postprocess({
'refined_box_encodings': refined_box_encodings,
'class_predictions_with_background': class_predictions_with_background,
'num_proposals': num_proposals,
'proposal_boxes': proposal_boxes,
'image_shape': image_shape
'refined_box_encodings': refined_box_encodings_placeholder,
'class_predictions_with_background':
class_predictions_with_background_placeholder,
'num_proposals': num_proposals_placeholder,
'proposal_boxes': proposal_boxes_placeholder,
'image_shape': image_shape_placeholder,
})
with self.test_session(graph=tf_graph) as sess:
detections_out = sess.run(
detections,
feed_dict={
refined_box_encodings_placeholder: refined_box_encodings,
class_predictions_with_background_placeholder:
class_predictions_with_background,
num_proposals_placeholder: num_proposals,
proposal_boxes_placeholder: proposal_boxes,
image_shape_placeholder: image_shape
})
with self.test_session() as sess:
detections_out = sess.run(detections)
self.assertAllEqual(detections_out['detection_boxes'].shape, [2, 5, 4])
self.assertAllClose(detections_out['detection_scores'],
[[1, 1, 1, 1, 1], [1, 1, 1, 1, 0]])
......@@ -571,6 +589,17 @@ class FasterRCNNMetaArchTestBase(tf.test.TestCase):
[[0, 0, 0, 1, 1], [0, 0, 1, 1, 0]])
self.assertAllClose(detections_out['num_detections'], [5, 4])
def test_preprocess_preserves_input_shapes(self):
image_shapes = [(3, None, None, 3),
(None, 10, 10, 3),
(None, None, None, 3)]
for image_shape in image_shapes:
model = self._build_model(
is_training=False, first_stage_only=False, second_stage_batch_size=6)
image_placeholder = tf.placeholder(tf.float32, shape=image_shape)
preprocessed_inputs = model.preprocess(image_placeholder)
self.assertAllEqual(preprocessed_inputs.shape.as_list(), image_shape)
def test_loss_first_stage_only_mode(self):
model = self._build_model(
is_training=True, first_stage_only=True, second_stage_batch_size=6)
......
......@@ -111,3 +111,26 @@ def pad_or_clip_tensor(t, length):
if not _is_tensor(length):
processed_t = _set_dim_0(processed_t, length)
return processed_t
def combined_static_and_dynamic_shape(tensor):
"""Returns a list containing static and dynamic values for the dimensions.
Returns a list of static and dynamic values for shape dimensions. This is
useful to preserve static shapes when available in reshape operation.
Args:
tensor: A tensor of any type.
Returns:
A list of size tensor.shape.ndims containing integers or a scalar tensor.
"""
static_shape = tensor.shape.as_list()
dynamic_shape = tf.shape(tensor)
combined_shape = []
for index, dim in enumerate(static_shape):
if dim is not None:
combined_shape.append(dim)
else:
combined_shape.append(dynamic_shape[index])
return combined_shape
......@@ -115,6 +115,13 @@ class UtilTest(tf.test.TestCase):
self.assertAllEqual([1, 2], tt3_result)
self.assertAllClose([[0.1, 0.2], [0.2, 0.4]], tt4_result)
def test_combines_static_dynamic_shape(self):
tensor = tf.placeholder(tf.float32, shape=(None, 2, 3))
combined_shape = shape_utils.combined_static_and_dynamic_shape(
tensor)
self.assertTrue(tf.contrib.framework.is_tensor(combined_shape[0]))
self.assertListEqual(combined_shape[1:], [2, 3])
if __name__ == '__main__':
tf.test.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment