Unverified Commit 8b641b13 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 7cffacfe 357fa547
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
import json import json
import os import os
import random import random
......
# Proper Reuse of Image Classification Features Improves Object Detection
Coming soon
1. CVPR 2022 paper
2. Table of results
\ No newline at end of file
...@@ -20,7 +20,7 @@ from official.core import config_definitions as cfg ...@@ -20,7 +20,7 @@ from official.core import config_definitions as cfg
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.modeling import optimization from official.modeling import optimization
from official.vision.beta.configs import common from official.vision.configs import common
@dataclasses.dataclass @dataclasses.dataclass
......
...@@ -20,7 +20,7 @@ import tensorflow as tf ...@@ -20,7 +20,7 @@ import tensorflow as tf
from official.modeling import tf_utils from official.modeling import tf_utils
from official.projects.basnet.modeling import nn_blocks from official.projects.basnet.modeling import nn_blocks
from official.vision.beta.modeling.backbones import factory from official.vision.modeling.backbones import factory
# Specifications for BASNet encoder. # Specifications for BASNet encoder.
# Each element in the block configuration is in the following format: # Each element in the block configuration is in the following format:
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
import tensorflow as tf import tensorflow as tf
from official.projects.basnet.tasks import basnet from official.projects.basnet.tasks import basnet
from official.vision.beta.serving import semantic_segmentation from official.vision.serving import semantic_segmentation
MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255) MEAN_RGB = (0.485 * 255, 0.456 * 255, 0.406 * 255)
......
...@@ -41,7 +41,7 @@ from absl import flags ...@@ -41,7 +41,7 @@ from absl import flags
from official.core import exp_factory from official.core import exp_factory
from official.modeling import hyperparams from official.modeling import hyperparams
from official.projects.basnet.serving import basnet from official.projects.basnet.serving import basnet
from official.vision.beta.serving import export_saved_model_lib from official.vision.serving import export_saved_model_lib
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
......
...@@ -27,7 +27,7 @@ from official.projects.basnet.evaluation import metrics as basnet_metrics ...@@ -27,7 +27,7 @@ from official.projects.basnet.evaluation import metrics as basnet_metrics
from official.projects.basnet.losses import basnet_losses from official.projects.basnet.losses import basnet_losses
from official.projects.basnet.modeling import basnet_model from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet from official.projects.basnet.modeling import refunet
from official.vision.beta.dataloaders import segmentation_input from official.vision.dataloaders import segmentation_input
def build_basnet_model( def build_basnet_model(
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver.""" """TensorFlow Model Garden Vision training driver."""
from absl import app from absl import app
...@@ -23,7 +22,7 @@ from official.projects.basnet.configs import basnet as basnet_cfg ...@@ -23,7 +22,7 @@ from official.projects.basnet.configs import basnet as basnet_cfg
from official.projects.basnet.modeling import basnet_model from official.projects.basnet.modeling import basnet_model
from official.projects.basnet.modeling import refunet from official.projects.basnet.modeling import refunet
from official.projects.basnet.tasks import basnet as basenet_task from official.projects.basnet.tasks import basnet as basenet_task
from official.vision.beta import train from official.vision import train
if __name__ == '__main__': if __name__ == '__main__':
......
# Contextualized Spatial-Temporal Contrastive Learning with Self-Supervision
(WIP) This repository contains the official implementation of
[Contextualized Spatio-Temporal Contrastive Learning with Self-Supervision](https://arxiv.org/abs/2112.05181)
in TF2.
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for instance_heads.py.""" """Tests for instance_heads.py."""
# Import libraries # Import libraries
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for maskrcnn_model.py.""" """Tests for maskrcnn_model.py."""
# Import libraries # Import libraries
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Detection input and model functions for serving/inference.""" """Detection input and model functions for serving/inference."""
from typing import Dict, Mapping, Text from typing import Dict, Mapping, Text
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Test for image detection export lib.""" """Test for image detection export lib."""
import io import io
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver.""" """TensorFlow Model Garden Vision training driver."""
from absl import app from absl import app
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Semantic segmentation configuration definition. """Semantic segmentation configuration definition.
The segmentation model is built using the mobilenet edgetpu v2 backbone and The segmentation model is built using the mobilenet edgetpu v2 backbone and
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for MobileNet.""" """Tests for MobileNet."""
# Import libraries # Import libraries
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for image classification task.""" """Tests for image classification task."""
# pylint: disable=unused-import # pylint: disable=unused-import
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""Tests for semantic segmentation task.""" """Tests for semantic segmentation task."""
# pylint: disable=unused-import # pylint: disable=unused-import
...@@ -20,12 +19,12 @@ from absl.testing import parameterized ...@@ -20,12 +19,12 @@ from absl.testing import parameterized
import orbit import orbit
import tensorflow as tf import tensorflow as tf
from official import vision
from official.core import exp_factory from official.core import exp_factory
from official.modeling import optimization from official.modeling import optimization
from official.projects.edgetpu.vision.configs import semantic_segmentation_config as seg_cfg from official.projects.edgetpu.vision.configs import semantic_segmentation_config as seg_cfg
from official.projects.edgetpu.vision.configs import semantic_segmentation_searched_config as autoseg_cfg from official.projects.edgetpu.vision.configs import semantic_segmentation_searched_config as autoseg_cfg
from official.projects.edgetpu.vision.tasks import semantic_segmentation as img_seg_task from official.projects.edgetpu.vision.tasks import semantic_segmentation as img_seg_task
from official.vision import beta
# Dummy ADE20K TF dataset. # Dummy ADE20K TF dataset.
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training for MobileNet-EdgeTPU.""" """TensorFlow Model Garden Vision training for MobileNet-EdgeTPU."""
from absl import app from absl import app
......
# Longformer: The Long-Document Transformer
## Modifications from Huggingface's Implementation
All models require a `global_attention_size` specified in the config, setting a
global attention for all first `global_attention_size` tokens in any sentence.
Individual different global attention sizes for sentences are not supported.
This setting allows running on TPUs where tensor sizes have to be determined.
`_get_global_attn_indices` in `longformer_attention.py` contains how the new
global attention indices are specified. Changed all `tf.cond` to if
confiditions, since global attention is specified in the start now.
To load weights from a pre-trained huggingface longformer, run
`utils/convert_pretrained_pytorch_checkpoint_to_tf.py` to create a checkpoint. \
There is also a `utils/longformer_tokenizer_to_tfrecord.py` that transformers
pytorch longformer tokenized data to tf_records.
## Steps to Fine-tune on MNLI
#### Prepare the pre-trained checkpoint
Option 1. Use our saved checkpoint of `allenai/longformer-base-4096` stored in cloud storage
```bash
gsutil cp -r gs://model-garden-ucsd-zihan/longformer-4096 .
```
Option 2. Create it directly
```bash
python3 utils/convert_pretrained_pytorch_checkpoint_to_tf.py
```
#### [Optional] Prepare the input file
```bash
python3 longformer_tokenizer_to_tfrecord.py
```
#### Training
Here, we use the training data of MNLI that were uploaded to the cloud storage, you can replace it with the input files you generated.
```bash
TRAIN_DATA=task.train_data.input_path=gs://model-garden-ucsd-zihan/longformer_allenai_mnli_train.tf_record,task.validation_data.input_path=gs://model-garden-ucsd-zihan/longformer_allenai_mnli_eval.tf_record
INIT_CHECKPOINT=longformer-4096/longformer
PYTHONPATH=/path/to/model/garden \
python3 train.py \
--experiment=longformer/glue \
--config_file=experiments/glue_mnli_allenai.yaml \
--params_override="${TRAIN_DATA},runtime.distribution_strategy=tpu,task.init_checkpoint=${INIT_CHECKPOINT}" \
--tpu=local \
--model_dir=/path/to/outputdir \
--mode=train_and_eval
```
This should take ~ 3 hours to run, and give a performance of ~86.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment