Commit d6668868, authored Mar 14, 2022 by Hongkun Yu, committed by A. Unique TensorFlower on Mar 14, 2022
Parent: 4bb36073

Open source labse project.

PiperOrigin-RevId: 434592648

Showing 7 changed files with 578 additions and 0 deletions.
official/projects/labse/README.md                           +111  -0
official/projects/labse/config_labse.py                      +68  -0
official/projects/labse/experiments/labse_base.yaml          +85  -0
official/projects/labse/experiments/labse_bert_base.yaml     +15  -0
official/projects/labse/export_tfhub.py                     +161  -0
official/projects/labse/export_tfhub_test.py                +111  -0
official/projects/labse/train.py                             +27  -0
official/projects/labse/README.md  0 → 100644 (new file)
# Language-agnostic BERT Sentence Embedding

This repository contains the implementation and experiment definitions for
`LaBSE`, [Language-agnostic BERT Sentence Embedding](https://arxiv.org/pdf/2007.01852.pdf).

The implementation is provided by the paper author, Yinfei Yang. Note that the
cross-accelerator batch softmax is not implemented here, so this implementation
does not yet fully reproduce the paper.

Due to data policy, the authors are not able to release the pre-training and
fine-tuning data for `LaBSE` training.
### Requirements

The starter code requires TensorFlow. If you haven't installed it yet, follow
the instructions on [tensorflow.org][1].

This code has been tested with TensorFlow 2.8.0. Going forward, we will
continue to target the latest released version of TensorFlow. Please verify
that you have Python 3.7+ and TensorFlow 2.8.0 or higher installed by running
the following commands:

```sh
python --version
python -c 'import tensorflow as tf; print(tf.__version__)'
```
Refer to the [instructions here][2] for using the model in this repo. Make sure
to add the models folder to your Python path.

[1]: https://www.tensorflow.org/install/
[2]: https://github.com/tensorflow/models/tree/master/official#running-the-models
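If you prefer to set this up from inside Python rather than via `PYTHONPATH`,
a minimal sketch (the checkout location is a hypothetical placeholder):

```python
import sys

# Hypothetical path to your clone of tensorflow/models; adjust as needed.
sys.path.append("/path/to/models")

import official  # should now import without error
```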
## Data

The pre-training data should be multilingual, and its format is the same as for
BERT pre-training.

The fine-tuning data follows the format below:

```text
{ # (tensorflow.Example)
  features: {
    feature: {
      key  : "src_raw"
      value: {
        bytes_list: {
          value: [ "Foo. " ]
        }
      }
    }
    feature: {
      key  : "tgt_raw"
      value: {
        bytes_list: {
          value: [ "Bar. " ]
        }
      }
    }
  }
}
```
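For reference, a minimal sketch of serializing one such record with
`tf.train.Example`; the feature keys come from the format above, while the
output path is a hypothetical placeholder:

```python
import tensorflow as tf

def make_pair_example(src_text: str, tgt_text: str) -> tf.train.Example:
  """Builds a tf.train.Example with `src_raw` / `tgt_raw` byte features."""
  def _bytes_feature(text: str) -> tf.train.Feature:
    return tf.train.Feature(
        bytes_list=tf.train.BytesList(value=[text.encode("utf-8")]))

  return tf.train.Example(
      features=tf.train.Features(feature={
          "src_raw": _bytes_feature(src_text),
          "tgt_raw": _bytes_feature(tgt_text),
      }))

# Write a small TFRecord of source/target pairs (path is hypothetical).
with tf.io.TFRecordWriter("/tmp/labse_finetune.tfrecord") as writer:
  writer.write(make_pair_example("Foo. ", "Bar. ").SerializeToString())
```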
## Train using the config file

After you have generated your pre-training data, run the following command to
start pre-training:

```bash
TPU=local
VOCAB=???
INIT_CHECKPOINT=???
PARAMS="task.train_data.input_path=/path/to/train/data"
PARAMS="${PARAMS},task.train_data.vocab_file=${VOCAB}"
PARAMS="${PARAMS},task.validation_data.input_path=/path/to/validation/data"
PARAMS="${PARAMS},task.validation_data.vocab_file=${VOCAB}"
PARAMS="${PARAMS},task.init_checkpoint=${INIT_CHECKPOINT}"
PARAMS="${PARAMS},runtime.distribution_strategy=tpu"

python3 train.py \
  --experiment=labse/train \
  --config_file=./experiments/labse_bert_base.yaml \
  --config_file=./experiments/labse_base.yaml \
  --params_override=${PARAMS} \
  --tpu=${TPU} \
  --model_dir=/folder/to/hold/logs/and/models/ \
  --mode=train_and_eval
```
## Implementation

We implement the encoder and layers using `tf.keras` APIs from the NLP modeling
library:

*   [dual_encoder.py](https://github.com/tensorflow/models/blob/master/official/nlp/tasks/dual_encoder.py)
    contains the dual-encoder task used for LaBSE training.
*   [config_labse.py](https://github.com/tensorflow/models/blob/master/official/projects/labse/config_labse.py)
    registers the LaBSE training experiment.
*   [train.py](https://github.com/tensorflow/models/blob/master/official/projects/labse/train.py)
    is the program entry point.
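For orientation, a minimal sketch of retrieving the registered experiment
programmatically (assuming the `official` package is importable):

```python
from official.core import exp_factory
# pylint: disable=unused-import
from official.projects.labse import config_labse  # side effect: registers "labse/train"

# Build the default experiment config registered in config_labse.py.
config = exp_factory.get_exp_config("labse/train")
print(config.trainer.optimizer_config.optimizer.type)  # adamw
print(config.task.validation_data.is_training)         # False
```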
## Pre-trained model through TF-Hub

If you are looking for pre-trained models, please check out
https://tfhub.dev/google/LaBSE/2. The hub `SavedModel`s are exported with the
`export_tfhub.py` script in this repository.
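A minimal sketch of loading an exported LaBSE `SavedModel` as a Keras layer and
running it on already-tokenized inputs; it mirrors `export_tfhub_test.py`
below, and the export path is a hypothetical placeholder:

```python
import numpy as np
import tensorflow_hub as hub

# Load the exported encoder (see export_tfhub.py for how it is produced).
encoder = hub.KerasLayer("/tmp/labse_model", trainable=False)

# The exported model expects [input_word_ids, input_mask, input_type_ids].
seq_length = 10
dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
outputs = encoder([dummy_ids, dummy_ids, dummy_ids])

print(outputs["pooled_output"].shape)    # (2, hidden_size): sentence embeddings
print(outputs["sequence_output"].shape)  # (2, seq_length, hidden_size)
```

With `--normalize=True` (the default in `export_tfhub.py`), `pooled_output` is
normalized, so cosine similarity between sentence pairs reduces to a dot
product.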
official/projects/labse/config_labse.py  0 → 100644 (new file)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pylint: disable=g-doc-return-or-yield,line-too-long
"""LaBSE configurations."""
import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import optimization
from official.nlp.data import dual_encoder_dataloader
from official.nlp.tasks import dual_encoder

AdamWeightDecay = optimization.AdamWeightDecayConfig
PolynomialLr = optimization.PolynomialLrConfig
PolynomialWarmupConfig = optimization.PolynomialWarmupConfig


@dataclasses.dataclass
class LaBSEOptimizationConfig(optimization.OptimizationConfig):
  """Bert optimization config."""
  optimizer: optimization.OptimizerConfig = optimization.OptimizerConfig(
      type="adamw", adamw=AdamWeightDecay())
  learning_rate: optimization.LrConfig = optimization.LrConfig(
      type="polynomial",
      polynomial=PolynomialLr(
          initial_learning_rate=1e-4,
          decay_steps=1000000,
          end_learning_rate=0.0))
  warmup: optimization.WarmupConfig = optimization.WarmupConfig(
      type="polynomial", polynomial=PolynomialWarmupConfig(warmup_steps=10000))


@exp_factory.register_config_factory("labse/train")
def labse_train() -> cfg.ExperimentConfig:
  r"""Language-agnostic bert sentence embedding.

  *Note*: this experiment does not use cross-accelerator global softmax so it
  does not reproduce the exact LABSE training.
  """
  config = cfg.ExperimentConfig(
      task=dual_encoder.DualEncoderConfig(
          train_data=dual_encoder_dataloader.DualEncoderDataConfig(),
          validation_data=dual_encoder_dataloader.DualEncoderDataConfig(
              is_training=False, drop_remainder=False)),
      trainer=cfg.TrainerConfig(
          optimizer_config=LaBSEOptimizationConfig(
              learning_rate=optimization.LrConfig(
                  type="polynomial",
                  polynomial=PolynomialLr(
                      initial_learning_rate=3e-5, end_learning_rate=0.0)),
              warmup=optimization.WarmupConfig(
                  type="polynomial", polynomial=PolynomialWarmupConfig()))),
      restrictions=[
          "task.train_data.is_training != None",
          "task.validation_data.is_training != None"
      ])
  return config
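How this optimization config gets consumed is not shown in this commit; as an
assumption-laden sketch, the Model Garden `OptimizerFactory` would typically
turn it into a concrete optimizer like this:

```python
from official.modeling import optimization
from official.projects.labse import config_labse

opt_config = config_labse.LaBSEOptimizationConfig()
opt_factory = optimization.OptimizerFactory(opt_config)

# Polynomial decay with warmup, then the AdamW optimizer configured above.
learning_rate = opt_factory.build_learning_rate()
optimizer = opt_factory.build_optimizer(learning_rate)
```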
official/projects/labse/experiments/labse_base.yaml  0 → 100644 (new file)
task:
  hub_module_url: ''
  model:
    bidirectional: true
    max_sequence_length: 32
    logit_scale: 100
    logit_margin: 0.3
  init_checkpoint: 'the pre-trained BERT checkpoint using the labse vocab.'
  train_data:
    drop_remainder: true
    global_batch_size: 4096
    input_path: 'the path to train partition'
    left_text_fields: ['src_raw']
    right_text_fields: ['tgt_raw']
    vocab_file: 'the path to vocab.txt'
    lower_case: false
    is_training: true
    seq_length: 32
    sharding: false
    cycle_length: 4
    shuffle_buffer_size: 1000
    tfds_as_supervised: false
    tfds_data_dir: ''
    tfds_name: ''
    tfds_skip_decoding_feature: ''
    tfds_split: ''
  validation_data:
    block_length: 1
    cache: false
    cycle_length: 4
    drop_remainder: false
    global_batch_size: 32000
    input_path: 'the path to validation partition'
    left_text_fields: ['src_raw']
    right_text_fields: ['tgt_raw']
    vocab_file: 'the path to vocab.txt'
    lower_case: false
    is_training: false
    seq_length: 32
    sharding: true
    shuffle_buffer_size: 1000
    tfds_as_supervised: false
    tfds_data_dir: ''
    tfds_name: ''
    tfds_skip_decoding_feature: ''
    tfds_split: ''
trainer:
  checkpoint_interval: 1000
  eval_tf_function: true
  max_to_keep: 5
  optimizer_config:
    learning_rate:
      polynomial:
        cycle: false
        decay_steps: 500000
        end_learning_rate: 0.0
        initial_learning_rate: 1.0e-04
        name: PolynomialDecay
        power: 1.0
      type: polynomial
    optimizer:
      adamw:
        amsgrad: false
        beta_1: 0.9
        beta_2: 0.999
        epsilon: 1.0e-05
        exclude_from_weight_decay: null
        include_in_weight_decay: null
        name: AdamWeightDecay
        weight_decay_rate: 0.0
        gradient_clip_norm: 100
      type: adamw
    warmup:
      polynomial:
        name: polynomial
        power: 1
        warmup_steps: 5000
      type: polynomial
  steps_per_loop: 1000
  summary_interval: 1000
  train_tf_function: true
  train_tf_while_loop: true
  train_steps: 500000
  validation_interval: 1000
  validation_steps: 100
official/projects/labse/experiments/labse_bert_base.yaml  0 → 100644 (new file)
task:
  model:
    encoder:
      bert:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        hidden_activation: gelu
        hidden_size: 768
        initializer_range: 0.02
        intermediate_size: 3072
        max_position_embeddings: 512
        num_attention_heads: 12
        num_layers: 12
        type_vocab_size: 2
        vocab_size: 501153
official/projects/labse/export_tfhub.py  0 → 100644 (new file)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Exports the LaBSE model and its preprocessing as SavedModels for TF Hub.

Example usage:

# Point this variable to your training results.
# Note that flag --do_lower_case is inferred from the name.
LaBSE_DIR=<Your LaBSE model dir>

# Step 1: export the core LaBSE model.
python3 ./export_tfhub.py \
  --bert_config_file ${LaBSE_DIR:?}/bert_config.json \
  --model_checkpoint_path ${LaBSE_DIR:?}/labse_model.ckpt \
  --vocab_file ${LaBSE_DIR:?}/vocab.txt \
  --export_type model --export_path /tmp/labse_model

# Step 2: export matching preprocessing (be sure to use same flags).
python3 ./export_tfhub.py \
  --vocab_file ${LaBSE_DIR:?}/vocab.txt \
  --export_type preprocessing --export_path /tmp/labse_preprocessing
"""

from typing import Text

from absl import app
from absl import flags
from absl import logging
import tensorflow as tf

from official.legacy.bert import bert_models
from official.legacy.bert import configs
from official.nlp.modeling import models
from official.nlp.tasks import utils
from official.nlp.tools import export_tfhub_lib

FLAGS = flags.FLAGS

flags.DEFINE_enum("export_type", "model", ["model", "preprocessing"],
                  "The type of model to export")
flags.DEFINE_string("export_path", None,
                    "TF-Hub SavedModel destination path.")
flags.DEFINE_string(
    "bert_tfhub_module", None,
    "Bert tfhub module to define core bert layers. Needed for --export_type "
    "model.")
flags.DEFINE_string(
    "bert_config_file", None,
    "Bert configuration file to define core bert layers. It will not be used "
    "if bert_tfhub_module is set. Needed for --export_type model.")
flags.DEFINE_string(
    "model_checkpoint_path", None,
    "File path to TF model checkpoint. "
    "Needed for --export_type model.")
flags.DEFINE_string(
    "vocab_file", None,
    "The vocabulary file that the BERT model was trained on. "
    "Needed for both --export_type model and preprocessing.")
flags.DEFINE_bool(
    "do_lower_case", None,
    "Whether to lowercase before tokenization. If left as None, "
    "do_lower_case will be enabled if 'uncased' appears in the "
    "name of --vocab_file. "
    "Needed for both --export_type model and preprocessing.")
flags.DEFINE_integer(
    "default_seq_length", 128,
    "The sequence length of preprocessing results from "
    "top-level preprocess method. This is also the default "
    "sequence length for the bert_pack_inputs subobject. "
    "Needed for --export_type preprocessing.")
flags.DEFINE_bool(
    "tokenize_with_offsets", False,  # TODO(b/181866850)
    "Whether to export a .tokenize_with_offsets subobject for "
    "--export_type preprocessing.")
flags.DEFINE_bool(
    "normalize", True,
    "Parameter of DualEncoder model, normalize the embedding (pooled_output) "
    "if set to True.")


def _get_do_lower_case(do_lower_case, vocab_file):
  """Returns do_lower_case, replacing None by a guess from vocab file name."""
  if do_lower_case is None:
    do_lower_case = "uncased" in vocab_file
    logging.info("Using do_lower_case=%s based on name of vocab_file=%s",
                 do_lower_case, vocab_file)
  return do_lower_case


def create_labse_model(bert_tfhub_module: Text,
                       bert_config: configs.BertConfig,
                       normalize: bool) -> tf.keras.Model:
  """Creates a LaBSE keras core model from BERT configuration.

  Args:
    bert_tfhub_module: The bert tfhub module path. The LaBSE will be built upon
      the tfhub module if it is not empty.
    bert_config: A `BertConfig` to create the core model. Used if
      bert_tfhub_module is empty.
    normalize: Parameter of DualEncoder model, normalize the embedding
      (pooled_output) if set to True.

  Returns:
    A keras model.
  """
  if bert_tfhub_module:
    encoder_network = utils.get_encoder_from_hub(bert_tfhub_module)
  else:
    encoder_network = bert_models.get_transformer_encoder(
        bert_config, sequence_length=None)
  labse_model = models.DualEncoder(
      network=encoder_network,
      max_seq_length=None,
      normalize=normalize,
      output="predictions")
  return labse_model, encoder_network  # pytype: disable=bad-return-type  # typed-keras


def export_labse_model(bert_tfhub_module: Text,
                       bert_config: configs.BertConfig,
                       model_checkpoint_path: Text, hub_destination: Text,
                       vocab_file: Text, do_lower_case: bool, normalize: bool):
  """Restores a tf.keras.Model and saves for TF-Hub."""
  core_model, encoder = create_labse_model(
      bert_tfhub_module, bert_config, normalize)
  checkpoint = tf.train.Checkpoint(encoder=encoder)
  checkpoint.restore(model_checkpoint_path).assert_existing_objects_matched()
  core_model.vocab_file = tf.saved_model.Asset(vocab_file)
  core_model.do_lower_case = tf.Variable(do_lower_case, trainable=False)
  core_model.save(hub_destination, include_optimizer=False, save_format="tf")


def main(_):
  do_lower_case = export_tfhub_lib.get_do_lower_case(
      FLAGS.do_lower_case, FLAGS.vocab_file)
  if FLAGS.export_type == "model":
    if FLAGS.bert_tfhub_module:
      bert_config = None
    else:
      bert_config = configs.BertConfig.from_json_file(FLAGS.bert_config_file)
    export_labse_model(FLAGS.bert_tfhub_module, bert_config,
                       FLAGS.model_checkpoint_path, FLAGS.export_path,
                       FLAGS.vocab_file, do_lower_case, FLAGS.normalize)
  elif FLAGS.export_type == "preprocessing":
    # LaBSE is still a BERT model, reuse the export_bert_preprocessing here.
    export_tfhub_lib.export_bert_preprocessing(
        FLAGS.export_path, FLAGS.vocab_file, do_lower_case,
        FLAGS.default_seq_length, FLAGS.tokenize_with_offsets)
  else:
    raise app.UsageError("Unknown value '%s' for flag --export_type" %
                         FLAGS.export_type)


if __name__ == "__main__":
  app.run(main)
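A minimal end-to-end sketch of consuming the two exported SavedModels together;
it assumes the preprocessing export follows the standard BERT preprocessing
signature (a batch of strings in, packed `input_word_ids` / `input_mask` /
`input_type_ids` out), and the paths are the hypothetical ones from the
docstring above:

```python
import tensorflow as tf
import tensorflow_hub as hub

preprocessor = hub.KerasLayer("/tmp/labse_preprocessing")  # step 2 output
encoder = hub.KerasLayer("/tmp/labse_model")               # step 1 output

sentences = tf.constant(["Hello world.", "Hola mundo."])
packed = preprocessor(sentences)
embeddings = encoder([
    packed["input_word_ids"], packed["input_mask"], packed["input_type_ids"]
])["pooled_output"]
print(embeddings.shape)  # (2, hidden_size)
```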
official/projects/labse/export_tfhub_test.py  0 → 100644 (new file)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests LaBSE's export_tfhub."""
import os

# Import libraries
import numpy as np
import tensorflow as tf
import tensorflow_hub as hub

from official.legacy.bert import configs
from official.projects.labse import export_tfhub


class ExportModelTest(tf.test.TestCase):

  def test_export_model(self):
    # Exports a savedmodel for TF-Hub
    hidden_size = 16
    bert_config = configs.BertConfig(
        vocab_size=100,
        hidden_size=hidden_size,
        intermediate_size=32,
        max_position_embeddings=128,
        num_attention_heads=2,
        num_hidden_layers=1)
    labse_model, encoder = export_tfhub.create_labse_model(
        None, bert_config, normalize=True)
    model_checkpoint_dir = os.path.join(self.get_temp_dir(), "checkpoint")
    checkpoint = tf.train.Checkpoint(encoder=encoder)
    checkpoint.save(os.path.join(model_checkpoint_dir, "test"))
    model_checkpoint_path = tf.train.latest_checkpoint(model_checkpoint_dir)

    vocab_file = os.path.join(self.get_temp_dir(), "uncased_vocab.txt")
    with tf.io.gfile.GFile(vocab_file, "w") as f:
      f.write("dummy content")

    hub_destination = os.path.join(self.get_temp_dir(), "hub")
    export_tfhub.export_labse_model(
        None,  # bert_tfhub_module
        bert_config,
        model_checkpoint_path,
        hub_destination,
        vocab_file,
        do_lower_case=True,
        normalize=True)

    # Restores a hub KerasLayer.
    hub_layer = hub.KerasLayer(hub_destination, trainable=True)

    if hasattr(hub_layer, "resolved_object"):
      # Checks meta attributes.
      self.assertTrue(hub_layer.resolved_object.do_lower_case.numpy())
      with tf.io.gfile.GFile(
          hub_layer.resolved_object.vocab_file.asset_path.numpy()) as f:
        self.assertEqual("dummy content", f.read())
    # Checks the hub KerasLayer.
    for source_weight, hub_weight in zip(labse_model.trainable_weights,
                                         hub_layer.trainable_weights):
      self.assertAllClose(source_weight.numpy(), hub_weight.numpy())

    seq_length = 10
    dummy_ids = np.zeros((2, seq_length), dtype=np.int32)
    hub_outputs = hub_layer([dummy_ids, dummy_ids, dummy_ids])
    source_outputs = labse_model([dummy_ids, dummy_ids, dummy_ids])

    self.assertEqual(hub_outputs["pooled_output"].shape, (2, hidden_size))
    self.assertEqual(hub_outputs["sequence_output"].shape,
                     (2, seq_length, hidden_size))
    for output_name in source_outputs:
      self.assertAllClose(source_outputs[output_name].numpy(),
                          hub_outputs[output_name].numpy())

    # Test that training=True makes a difference (activates dropout).
    def _dropout_mean_stddev(training, num_runs=20):
      input_ids = np.array([[14, 12, 42, 95, 99]], np.int32)
      inputs = [input_ids, np.ones_like(input_ids), np.zeros_like(input_ids)]
      outputs = np.concatenate([
          hub_layer(inputs, training=training)["pooled_output"]
          for _ in range(num_runs)
      ])
      return np.mean(np.std(outputs, axis=0))

    self.assertLess(_dropout_mean_stddev(training=False), 1e-6)
    self.assertGreater(_dropout_mean_stddev(training=True), 1e-3)

    # Test propagation of seq_length in shape inference.
    input_word_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_mask = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    input_type_ids = tf.keras.layers.Input(shape=(seq_length,), dtype=tf.int32)
    outputs = hub_layer([input_word_ids, input_mask, input_type_ids])
    self.assertEqual(outputs["pooled_output"].shape.as_list(),
                     [None, hidden_size])
    self.assertEqual(outputs["sequence_output"].shape.as_list(),
                     [None, seq_length, hidden_size])


if __name__ == "__main__":
  tf.test.main()
official/projects/labse/train.py  0 → 100644 (new file)
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensorFlow Model Garden LaBSE training driver; registers LaBSE configs."""
# pylint: disable=unused-import
from absl import app

from official.common import flags as tfm_flags
from official.nlp import tasks
from official.nlp import train
from official.projects.labse import config_labse

if __name__ == '__main__':
  tfm_flags.define_flags()
  app.run(train.main)