Unverified Commit 8b641b13 authored by Srihari Humbarwadi's avatar Srihari Humbarwadi Committed by GitHub
Browse files

Merge branch 'tensorflow:master' into panoptic-deeplab

parents 7cffacfe 357fa547
......@@ -12,15 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for teams_experiments."""
from absl.testing import parameterized
import tensorflow as tf
# pylint: disable=unused-import
from official.common import registry_imports
# pylint: enable=unused-import
from official.common import registry_imports # pylint: disable=unused-import
from official.core import config_definitions as cfg
from official.core import exp_factory
......
......@@ -95,10 +95,10 @@ modeling library:
Please cite our paper:
```
@inproceedings{pang2022,
@article{hou2022token,
title={Token Dropping for Efficient BERT Pretraining},
author={Richard Yuanzhe Pang*, Le Hou*, Tianyi Zhou, Yuexin Wu, Xinying Song, Xiaodan Song, Denny Zhou},
year={2022},
organization={Association for Computational Linguistics}
author={Pang, Richard Yuanzhe and Hou, Le and Zhou, Tianyi and Wu, Yuexin and Song, Xinying and Song, Xiaodan and Zhou, Denny},
journal={arXiv preprint arXiv:2203.13240},
year={2022}
}
```
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.projects.video_ssl.configs import video_ssl
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video classification configuration definition."""
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
# pylint: disable=unused-import
from absl.testing import parameterized
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Parser for video and label datasets."""
from typing import Dict, Optional, Tuple
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import io
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Define losses."""
# Import libraries
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Utils for customed ops for video ssl."""
import functools
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video ssl linear evaluation task definition."""
from typing import Any, Optional, List, Tuple
from absl import logging
......@@ -20,7 +19,7 @@ import tensorflow as tf
# pylint: disable=unused-import
from official.core import task_factory
from official.projects.video_ssl.configs.google import video_ssl as exp_cfg
from official.projects.video_ssl.configs import video_ssl as exp_cfg
from official.projects.video_ssl.modeling import video_ssl_model
from official.vision.tasks import video_classification
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Video ssl pretrain task definition."""
from absl import logging
import tensorflow as tf
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
import functools
import os
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Training driver."""
from absl import app
......@@ -28,7 +27,7 @@ from official.core import train_utils
from official.modeling import performance
from official.projects.video_ssl.modeling import video_ssl_model
from official.projects.video_ssl.tasks import linear_eval
from official.projects.video_ssl.tasks.google import pretrain
from official.projects.video_ssl.tasks import pretrain
from official.vision import registry_imports
# pylint: disable=unused-import
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.projects.vit.configs import image_classification
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Backbones configurations."""
from typing import Optional
......
......@@ -12,39 +12,31 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
import os
from typing import List, Optional
import dataclasses
import os
from typing import Optional
from official.core import config_definitions as cfg
from official.core import exp_factory
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.projects.vit.configs import backbones
from official.vision.configs import common
from official.vision.configs import image_classification as img_cls_cfg
from official.projects.vit.configs import backbones
from official.vision.tasks import image_classification
# pytype: disable=wrong-keyword-args
DataConfig = img_cls_cfg.DataConfig
@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
class ImageClassificationModel(img_cls_cfg.ImageClassificationModel):
"""The model config."""
num_classes: int = 0
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: backbones.Backbone = backbones.Backbone(
type='vit', vit=backbones.VisionTransformer())
dropout_rate: float = 0.0
norm_activation: common.NormActivation = common.NormActivation(
use_sync_bn=False)
# Adds a BatchNormalization layer pre-GlobalAveragePooling in classification
add_head_batch_norm: bool = False
kernel_initializer: str = 'random_uniform'
@dataclasses.dataclass
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""VisionTransformer models."""
import tensorflow as tf
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for VIT."""
from absl.testing import parameterized
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver, including ViT configs.."""
from absl import app
......
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Backbones configurations."""
import dataclasses
from typing import Optional, Sequence
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment