Unverified Commit c127d527 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'panoptic-segmentation' into panoptic-deeplab-modeling

parents 78657911 457bcb85
......@@ -403,7 +403,7 @@ message CenterNet {
// Mask prediction support using DeepMAC. See https://arxiv.org/abs/2104.00613
// Next ID 24
// Next ID 25
message DeepMACMaskEstimation {
// The loss used for penalizing mask predictions.
optional ClassificationLoss classification_loss = 1;
......@@ -485,6 +485,14 @@ message CenterNet {
optional int32 color_consistency_warmup_start = 23 [default=0];
// DeepMAC has been refactored to process the entire batch at once,
// instead of the previous (simple) approach of processing one sample at
// a time. Because of this, the memory consumption has increased and
// it's crucial to only feed the mask head the last stage outputs
// from the hourglass. Doing so halves the memory requirement of the
// mask head and does not cause a drop in evaluation metrics.
optional bool use_only_last_stage = 24 [default=false];
}
optional DeepMACMaskEstimation deepmac_mask_estimation = 14;
......
......@@ -76,3 +76,29 @@ py_strict_library(
"//tf_ops:sequence_string_projection_op_v2_py", # sequence projection
],
)
py_strict_library(
name = "misc_layers",
srcs = ["misc_layers.py"],
srcs_version = "PY3",
deps = [
# package tensorflow
"//layers:base_layers", # sequence projection
"//layers:dense_layers", # sequence projection
"//layers:quantization_layers", # sequence projection
],
)
py_strict_library(
name = "qrnn_layers",
srcs = ["qrnn_layers.py"],
srcs_version = "PY3",
deps = [
":base_layers",
":conv_layers",
":dense_layers",
":quantization_layers",
# package tensorflow
"//tf_ops:tf_custom_ops_py", # sequence projection
],
)
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -90,6 +90,7 @@ py_binary(
main = "run_tflite.py",
python_version = "PY3",
deps = [
":sgnn_projection_op_resolver",
# Expect numpy installed
# package TFLite flex delegate
# package TFLite interpreter
......
......@@ -43,8 +43,6 @@ Hparams = collections.namedtuple(
def preprocess(text):
"""Normalize the text, and return tokens."""
assert len(text.get_shape().as_list()) == 2
assert text.get_shape().as_list()[-1] == 1
text = tf.reshape(text, [-1])
text = tf_text.case_fold_utf8(text)
tokenizer = tflite_text_api.WhitespaceTokenizer()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment