ModelZoo / ResNet50_tensorflow · Commits · 219f6f06

Commit 219f6f06, authored Jun 18, 2021 by Xianzhi Du, committed by A. Unique TensorFlower on Jun 18, 2021.

Internal change

PiperOrigin-RevId: 380237223
Parent: c5783656
Showing 7 changed files with 601 additions and 0 deletions:

- official/vision/beta/projects/vit/README.md (+12, −0)
- official/vision/beta/projects/vit/configs/__init__.py (+18, −0)
- official/vision/beta/projects/vit/configs/backbones.py (+56, −0)
- official/vision/beta/projects/vit/configs/image_classification.py (+195, −0)
- official/vision/beta/projects/vit/modeling/vit.py (+249, −0)
- official/vision/beta/projects/vit/modeling/vit_test.py (+43, −0)
- official/vision/beta/projects/vit/train.py (+28, −0)
official/vision/beta/projects/vit/README.md  (new file, mode 100644)

# Vision Transformer (ViT)

**DISCLAIMER**: This implementation is still under development. No support will
be provided during the development phase.

[arXiv:2010.11929](https://arxiv.org/abs/2010.11929)

This repository contains the implementation of the Vision Transformer (ViT) in
TensorFlow 2.

* Paper title: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/pdf/2010.11929.pdf).
official/vision/beta/projects/vit/configs/__init__.py  (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Configs package definition."""
from official.vision.beta.projects.vit.configs import image_classification
```
official/vision/beta/projects/vit/configs/backbones.py  (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Backbones configurations."""
from typing import Optional

import dataclasses

from official.modeling import hyperparams


@dataclasses.dataclass
class Transformer(hyperparams.Config):
  """Transformer config."""
  mlp_dim: int = 1
  num_heads: int = 1
  num_layers: int = 1
  attention_dropout_rate: float = 0.0
  dropout_rate: float = 0.1


@dataclasses.dataclass
class VisionTransformer(hyperparams.Config):
  """VisionTransformer config."""
  model_name: str = 'vit-b16'
  # pylint: disable=line-too-long
  classifier: str = 'token'  # 'token' or 'gap'. If set to 'token', an extra classification token is added to sequence.
  # pylint: enable=line-too-long
  representation_size: int = 0
  hidden_size: int = 1
  patch_size: int = 16
  transformer: Transformer = Transformer()


@dataclasses.dataclass
class Backbone(hyperparams.OneOfConfig):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    vit: vit backbone config.
  """
  type: Optional[str] = None
  vit: VisionTransformer = VisionTransformer()
```
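
`Backbone` is a `OneOfConfig`: the `type` field names which of the remaining fields is active, and `get()` returns that sub-config (this is exactly how `build_vit` in `modeling/vit.py` consumes it below). A minimal sketch of composing and querying these dataclasses, assuming `official.modeling.hyperparams` from the Model Garden is importable:

```python
from official.vision.beta.projects.vit.configs import backbones

# Select the ViT backbone and override a couple of fields.
backbone = backbones.Backbone(
    type='vit',
    vit=backbones.VisionTransformer(model_name='vit-l16', hidden_size=1024))

# OneOfConfig.get() returns the sub-config selected by `type`.
vit_cfg = backbone.get()
print(vit_cfg.model_name)                # 'vit-l16'
print(vit_cfg.transformer.dropout_rate)  # 0.1, the Transformer default
```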
official/vision/beta/projects/vit/configs/image_classification.py  (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Image classification configuration definition."""
import os
from typing import List, Optional

import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.core import task_factory
from official.modeling import hyperparams
from official.modeling import optimization
from official.vision.beta.configs import common
from official.vision.beta.configs import image_classification as img_cls_cfg
from official.vision.beta.projects.vit.configs import backbones
from official.vision.beta.tasks import image_classification

DataConfig = img_cls_cfg.DataConfig


@dataclasses.dataclass
class ImageClassificationModel(hyperparams.Config):
  """The model config."""
  num_classes: int = 0
  input_size: List[int] = dataclasses.field(default_factory=list)
  backbone: backbones.Backbone = backbones.Backbone(
      type='vit', vit=backbones.VisionTransformer())
  dropout_rate: float = 0.0
  norm_activation: common.NormActivation = common.NormActivation(
      use_sync_bn=False)
  # Adds a BatchNormalization layer pre-GlobalAveragePooling in classification.
  add_head_batch_norm: bool = False


@dataclasses.dataclass
class Losses(hyperparams.Config):
  one_hot: bool = True
  label_smoothing: float = 0.0
  l2_weight_decay: float = 0.0


@dataclasses.dataclass
class Evaluation(hyperparams.Config):
  top_k: int = 5


@dataclasses.dataclass
class ImageClassificationTask(cfg.TaskConfig):
  """The task config. Same as the classification task for convnets."""
  model: ImageClassificationModel = ImageClassificationModel()
  train_data: DataConfig = DataConfig(is_training=True)
  validation_data: DataConfig = DataConfig(is_training=False)
  losses: Losses = Losses()
  evaluation: Evaluation = Evaluation()
  init_checkpoint: Optional[str] = None
  init_checkpoint_modules: str = 'all'  # all or backbone


IMAGENET_TRAIN_EXAMPLES = 1281167
IMAGENET_VAL_EXAMPLES = 50000
IMAGENET_INPUT_PATH_BASE = 'imagenet-2012-tfrecord'

# TODO(b/177942984): integrate the experiments to TF-vision.
task_factory.register_task_cls(ImageClassificationTask)(
    image_classification.ImageClassificationTask)


@exp_factory.register_config_factory('vit_imagenet_pretrain')
def image_classification_imagenet_vit_pretrain() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer."""
  train_batch_size = 4096
  eval_batch_size = 4096
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[224, 224, 3],
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(
                      model_name='vit-b16', representation_size=768))),
          losses=Losses(l2_weight_decay=0.0),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=300 * steps_per_epoch,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'adamw',
                  'adamw': {
                      'weight_decay_rate': 0.3,
                      'include_in_weight_decay': r'.*(kernel|weight):0$',
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.003,
                      'decay_steps': 300 * steps_per_epoch,
                  }
              },
              'warmup': {
                  'type': 'linear',
                  'linear': {
                      'warmup_steps': 10000,
                      'warmup_learning_rate': 0
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config


@exp_factory.register_config_factory('vit_imagenet_finetune')
def image_classification_imagenet_vit_finetune() -> cfg.ExperimentConfig:
  """Image classification on imagenet with vision transformer."""
  train_batch_size = 512
  eval_batch_size = 512
  steps_per_epoch = IMAGENET_TRAIN_EXAMPLES // train_batch_size
  config = cfg.ExperimentConfig(
      task=ImageClassificationTask(
          model=ImageClassificationModel(
              num_classes=1001,
              input_size=[384, 384, 3],
              backbone=backbones.Backbone(
                  type='vit',
                  vit=backbones.VisionTransformer(model_name='vit-b16'))),
          losses=Losses(l2_weight_decay=0.0),
          train_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'train*'),
              is_training=True,
              global_batch_size=train_batch_size),
          validation_data=DataConfig(
              input_path=os.path.join(IMAGENET_INPUT_PATH_BASE, 'valid*'),
              is_training=False,
              global_batch_size=eval_batch_size)),
      trainer=cfg.TrainerConfig(
          steps_per_loop=steps_per_epoch,
          summary_interval=steps_per_epoch,
          checkpoint_interval=steps_per_epoch,
          train_steps=20000,
          validation_steps=IMAGENET_VAL_EXAMPLES // eval_batch_size,
          validation_interval=steps_per_epoch,
          optimizer_config=optimization.OptimizationConfig({
              'optimizer': {
                  'type': 'sgd',
                  'sgd': {
                      'momentum': 0.9,
                      'global_clipnorm': 1.0,
                  }
              },
              'learning_rate': {
                  'type': 'cosine',
                  'cosine': {
                      'initial_learning_rate': 0.003,
                      'decay_steps': 20000,
                  }
              }
          })),
      restrictions=[
          'task.train_data.is_training != None',
          'task.validation_data.is_training != None'
      ])
  return config
```
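
Both factory functions are registered under string names, so downstream code can build the full experiment by name rather than constructing it field by field. A minimal sketch, assuming the Model Garden's `exp_factory.get_exp_config` lookup:

```python
from official.core import exp_factory
# Importing this module runs the @exp_factory.register_config_factory
# decorators above, making the experiment names resolvable.
from official.vision.beta.projects.vit.configs import image_classification  # pylint: disable=unused-import

config = exp_factory.get_exp_config('vit_imagenet_pretrain')
print(config.task.model.input_size)                    # [224, 224, 3]
print(config.trainer.optimizer_config.optimizer.type)  # 'adamw'

# Individual fields can still be overridden before training, e.g.:
config.task.train_data.global_batch_size = 256
```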
official/vision/beta/projects/vit/modeling/vit.py  (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""VisionTransformer models."""
import tensorflow as tf

from official.modeling import activations
from official.nlp import keras_nlp
from official.vision.beta.modeling.backbones import factory

layers = tf.keras.layers

VIT_SPECS = {
    'vit-testing':
        dict(
            hidden_size=1,
            patch_size=16,
            transformer=dict(mlp_dim=1, num_heads=1, num_layers=1),
        ),
    'vit-b16':
        dict(
            hidden_size=768,
            patch_size=16,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-b32':
        dict(
            hidden_size=768,
            patch_size=32,
            transformer=dict(mlp_dim=3072, num_heads=12, num_layers=12),
        ),
    'vit-l16':
        dict(
            hidden_size=1024,
            patch_size=16,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-l32':
        dict(
            hidden_size=1024,
            patch_size=32,
            transformer=dict(mlp_dim=4096, num_heads=16, num_layers=24),
        ),
    'vit-h14':
        dict(
            hidden_size=1280,
            patch_size=14,
            transformer=dict(mlp_dim=5120, num_heads=16, num_layers=32),
        ),
}


class AddPositionEmbs(tf.keras.layers.Layer):
  """Adds (optionally learned) positional embeddings to the inputs."""

  def __init__(self, posemb_init=None, **kwargs):
    super().__init__(**kwargs)
    self.posemb_init = posemb_init

  def build(self, inputs_shape):
    pos_emb_shape = (1, inputs_shape[1], inputs_shape[2])
    self.pos_embedding = self.add_weight(
        'pos_embedding', pos_emb_shape, initializer=self.posemb_init)

  def call(self, inputs, inputs_positions=None):
    # inputs.shape is (batch_size, seq_len, emb_dim).
    pos_embedding = tf.cast(self.pos_embedding, inputs.dtype)
    return inputs + pos_embedding


class TokenLayer(tf.keras.layers.Layer):
  """A simple layer to wrap token parameters."""

  def build(self, inputs_shape):
    self.cls = self.add_weight(
        'cls', (1, 1, inputs_shape[-1]), initializer='zeros')

  def call(self, inputs):
    cls = tf.cast(self.cls, inputs.dtype)
    cls = cls + tf.zeros_like(inputs[:, 0:1])  # A hacky way to tile.
    x = tf.concat([cls, inputs], axis=1)
    return x


class Encoder(tf.keras.layers.Layer):
  """Transformer Encoder."""

  def __init__(self,
               num_layers,
               mlp_dim,
               num_heads,
               dropout_rate=0.1,
               attention_dropout_rate=0.1,
               kernel_regularizer=None,
               inputs_positions=None,
               **kwargs):
    super().__init__(**kwargs)
    self._num_layers = num_layers
    self._mlp_dim = mlp_dim
    self._num_heads = num_heads
    self._dropout_rate = dropout_rate
    self._attention_dropout_rate = attention_dropout_rate
    self._kernel_regularizer = kernel_regularizer
    self._inputs_positions = inputs_positions

  def build(self, input_shape):
    self._pos_embed = AddPositionEmbs(
        posemb_init=tf.keras.initializers.RandomNormal(stddev=0.02),
        name='posembed_input')
    self._dropout = layers.Dropout(rate=self._dropout_rate)

    self._encoder_layers = []
    # Set layer norm epsilons to 1e-6 to be consistent with JAX implementation.
    # https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.LayerNorm.html
    for _ in range(self._num_layers):
      encoder_layer = keras_nlp.layers.TransformerEncoderBlock(
          inner_activation=activations.gelu,
          num_attention_heads=self._num_heads,
          inner_dim=self._mlp_dim,
          output_dropout=self._dropout_rate,
          attention_dropout=self._attention_dropout_rate,
          kernel_regularizer=self._kernel_regularizer,
          norm_first=True,
          norm_epsilon=1e-6)
      self._encoder_layers.append(encoder_layer)
    self._norm = layers.LayerNormalization(epsilon=1e-6)
    super().build(input_shape)

  def call(self, inputs, training=None):
    x = self._pos_embed(inputs, inputs_positions=self._inputs_positions)
    x = self._dropout(x, training=training)

    for encoder_layer in self._encoder_layers:
      x = encoder_layer(x, training=training)

    x = self._norm(x)
    return x


class VisionTransformer(tf.keras.Model):
  """Class to build VisionTransformer family model."""

  def __init__(self,
               mlp_dim=3072,
               num_heads=12,
               num_layers=12,
               attention_dropout_rate=0.0,
               dropout_rate=0.1,
               input_specs=layers.InputSpec(shape=[None, None, None, 3]),
               patch_size=16,
               hidden_size=768,
               representation_size=0,
               classifier='token',
               kernel_regularizer=None):
    """VisionTransformer initialization function."""
    inputs = tf.keras.Input(shape=input_specs.shape[1:])

    x = layers.Conv2D(
        filters=hidden_size,
        kernel_size=patch_size,
        strides=patch_size,
        padding='valid',
        kernel_regularizer=kernel_regularizer)(inputs)
    if tf.keras.backend.image_data_format() == 'channels_last':
      rows_axis, cols_axis = (1, 2)
    else:
      rows_axis, cols_axis = (2, 3)
      # The reshape below assumes the data_format is 'channels_last,' so
      # transpose to that. Once the data is flattened by the reshape, the
      # data_format is irrelevant, so no need to update
      # tf.keras.backend.image_data_format.
      x = tf.transpose(x, perm=[0, 2, 3, 1])
    seq_len = (input_specs.shape[rows_axis] // patch_size) * (
        input_specs.shape[cols_axis] // patch_size)
    x = tf.reshape(x, [-1, seq_len, hidden_size])

    # If we want to add a class token, add it here.
    if classifier == 'token':
      x = TokenLayer(name='cls')(x)

    x = Encoder(
        num_layers=num_layers,
        mlp_dim=mlp_dim,
        num_heads=num_heads,
        dropout_rate=dropout_rate,
        attention_dropout_rate=attention_dropout_rate,
        kernel_regularizer=kernel_regularizer)(x)

    if classifier == 'token':
      x = x[:, 0]
    elif classifier == 'gap':
      x = tf.reduce_mean(x, axis=1)

    if representation_size:
      x = tf.keras.layers.Dense(
          representation_size,
          kernel_regularizer=kernel_regularizer,
          name='pre_logits')(x)
      x = tf.nn.tanh(x)
    else:
      x = tf.identity(x, name='pre_logits')
    endpoints = {
        'pre_logits':
            tf.reshape(x, [-1, 1, 1, representation_size or hidden_size])
    }

    super(VisionTransformer, self).__init__(inputs=inputs, outputs=endpoints)


@factory.register_backbone_builder('vit')
def build_vit(input_specs,
              backbone_config,
              norm_activation_config,
              l2_regularizer=None):
  """Build ViT model."""
  del norm_activation_config
  backbone_type = backbone_config.type
  backbone_cfg = backbone_config.get()
  assert backbone_type == 'vit', (f'Inconsistent backbone type '
                                  f'{backbone_type}')
  backbone_cfg.override(VIT_SPECS[backbone_cfg.model_name])

  return VisionTransformer(
      mlp_dim=backbone_cfg.transformer.mlp_dim,
      num_heads=backbone_cfg.transformer.num_heads,
      num_layers=backbone_cfg.transformer.num_layers,
      attention_dropout_rate=backbone_cfg.transformer.attention_dropout_rate,
      dropout_rate=backbone_cfg.transformer.dropout_rate,
      input_specs=input_specs,
      patch_size=backbone_cfg.patch_size,
      hidden_size=backbone_cfg.hidden_size,
      representation_size=backbone_cfg.representation_size,
      classifier=backbone_cfg.classifier,
      kernel_regularizer=l2_regularizer)
```
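
The backbone can also be instantiated directly. A minimal sketch using the tiny `'vit-testing'` hyperparameters from `VIT_SPECS` above, assuming a fixed 224×224 RGB input (so `seq_len = (224 // 16) ** 2 = 196` patches plus one class token):

```python
import tensorflow as tf
from official.vision.beta.projects.vit.modeling import vit

model = vit.VisionTransformer(
    mlp_dim=1,
    num_heads=1,
    num_layers=1,
    patch_size=16,
    hidden_size=1,
    input_specs=tf.keras.layers.InputSpec(shape=[None, 224, 224, 3]))
outputs = model(tf.zeros([2, 224, 224, 3]))
# The model exposes a single 'pre_logits' endpoint, reshaped to NHWC so
# standard classification heads can consume it.
print(outputs['pre_logits'].shape)  # (2, 1, 1, 1)
```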
official/vision/beta/projects/vit/modeling/vit_test.py  (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for VIT."""
from absl.testing import parameterized
import tensorflow as tf

from official.vision.beta.projects.vit.modeling import vit


class VisionTransformerTest(parameterized.TestCase, tf.test.TestCase):

  @parameterized.parameters(
      (224, 85798656),
      (256, 85844736),
  )
  def test_network_creation(self, input_size, params_count):
    """Test creation of VisionTransformer family models."""
    tf.keras.backend.set_image_data_format('channels_last')
    input_specs = tf.keras.layers.InputSpec(
        shape=[2, input_size, input_size, 3])
    network = vit.VisionTransformer(input_specs=input_specs)

    inputs = tf.keras.Input(shape=(input_size, input_size, 3), batch_size=1)
    _ = network(inputs)
    self.assertEqual(network.count_params(), params_count)


if __name__ == '__main__':
  tf.test.main()
```
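
The expected counts are not magic numbers; they follow from the default ViT-B/16 hyperparameters (`hidden_size=768`, `mlp_dim=3072`, 12 heads, 12 layers). A sketch of the arithmetic for the 224×224 case:

```python
# Parameter count for ViT-B/16 at 224x224 (196 patches + 1 class token).
patch_embed = 16 * 16 * 3 * 768 + 768  # Conv2D patch projection + bias
cls_token = 768                        # TokenLayer weight
pos_embed = (196 + 1) * 768            # AddPositionEmbs
per_block = (
    4 * (768 * 768 + 768)              # q, k, v and output projections
    + (768 * 3072 + 3072)              # MLP expansion
    + (3072 * 768 + 768)               # MLP projection
    + 2 * (2 * 768))                   # two LayerNorms (gamma + beta)
final_norm = 2 * 768                   # encoder-final LayerNormalization
total = patch_embed + cls_token + pos_embed + 12 * per_block + final_norm
assert total == 85798656               # the value asserted for 224 above
```

At 256×256 only the position embedding grows: (256 // 16)² + 1 = 257 tokens adds 60 × 768 = 46,080 parameters, giving the 85,844,736 in the second test case.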
official/vision/beta/projects/vit/train.py  (new file, mode 100644)
```python
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""TensorFlow Model Garden Vision training driver, including ViT configs."""
from absl import app

from official.common import flags as tfm_flags
from official.vision.beta import train
from official.vision.beta.projects.vit import configs  # pylint: disable=unused-import
from official.vision.beta.projects.vit.modeling import vit  # pylint: disable=unused-import

if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(train.main)
```
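
A hypothetical launch command, assuming the standard Model Garden flags that `tfm_flags.define_flags()` registers (`--experiment`, `--mode`, `--model_dir`); exact flags, paths, and accelerator options depend on your environment:

```shell
python3 -m official.vision.beta.projects.vit.train \
  --experiment=vit_imagenet_pretrain \
  --mode=train_and_eval \
  --model_dir=/tmp/vit_imagenet_pretrain
```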