Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
ab834d35
Commit
ab834d35
authored
Sep 10, 2020
by
Hongkun Yu
Committed by
A. Unique TensorFlower
Sep 10, 2020
Browse files
Use distribution utils in XLNET
PiperOrigin-RevId: 331015243
parent
7d5c47aa
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
15 additions
and
46 deletions
+15
-46
official/nlp/xlnet/run_classifier.py
official/nlp/xlnet/run_classifier.py
+4
-14
official/nlp/xlnet/run_pretrain.py
official/nlp/xlnet/run_pretrain.py
+7
-18
official/nlp/xlnet/run_squad.py
official/nlp/xlnet/run_squad.py
+4
-14
No files found.
official/nlp/xlnet/run_classifier.py
View file @
ab834d35
...
...
@@ -14,11 +14,6 @@
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
functools
# Import libraries
from
absl
import
app
...
...
@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from
official.nlp.xlnet
import
training_utils
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_modeling
as
modeling
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
distribution_utils
flags
.
DEFINE_integer
(
"n_class"
,
default
=
2
,
help
=
"Number of classes."
)
flags
.
DEFINE_string
(
...
...
@@ -135,14 +130,9 @@ def get_metric_fn():
def
main
(
unused_argv
):
del
unused_argv
if
FLAGS
.
strategy_type
==
"mirror"
:
strategy
=
tf
.
distribute
.
MirroredStrategy
()
elif
FLAGS
.
strategy_type
==
"tpu"
:
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
FLAGS
.
tpu
)
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
else
:
raise
ValueError
(
"The distribution strategy type is not supported: %s"
%
FLAGS
.
strategy_type
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
strategy_type
,
tpu_address
=
FLAGS
.
tpu
)
if
strategy
:
logging
.
info
(
"***** Number of cores used : %d"
,
strategy
.
num_replicas_in_sync
)
...
...
official/nlp/xlnet/run_pretrain.py
View file @
ab834d35
...
...
@@ -12,12 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""XLNet classification finetuning runner in tf2.0."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
"""XLNet pretraining runner in tf2.0."""
import
functools
import
os
...
...
@@ -34,7 +29,7 @@ from official.nlp.xlnet import optimization
from
official.nlp.xlnet
import
training_utils
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_modeling
as
modeling
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
distribution_utils
flags
.
DEFINE_integer
(
"num_predict"
,
...
...
@@ -77,17 +72,11 @@ def get_pretrainxlnet_model(model_config, run_config):
def
main
(
unused_argv
):
del
unused_argv
num_hosts
=
1
if
FLAGS
.
strategy_type
==
"mirror"
:
strategy
=
tf
.
distribute
.
MirroredStrategy
()
elif
FLAGS
.
strategy_type
==
"tpu"
:
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
FLAGS
.
tpu
)
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
topology
=
FLAGS
.
tpu_topology
.
split
(
"x"
)
total_num_core
=
2
*
int
(
topology
[
0
])
*
int
(
topology
[
1
])
num_hosts
=
total_num_core
//
FLAGS
.
num_core_per_host
else
:
raise
ValueError
(
"The distribution strategy type is not supported: %s"
%
FLAGS
.
strategy_type
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
strategy_type
,
tpu_address
=
FLAGS
.
tpu
)
if
FLAGS
.
strategy_type
==
"tpu"
:
num_hosts
=
strategy
.
extended
.
num_hosts
if
strategy
:
logging
.
info
(
"***** Number of cores used : %d"
,
strategy
.
num_replicas_in_sync
)
...
...
official/nlp/xlnet/run_squad.py
View file @
ab834d35
...
...
@@ -14,11 +14,6 @@
# ==============================================================================
"""XLNet SQUAD finetuning runner in tf2.0."""
from
__future__
import
absolute_import
from
__future__
import
division
# from __future__ import google_type_annotations
from
__future__
import
print_function
import
functools
import
json
import
os
...
...
@@ -39,7 +34,7 @@ from official.nlp.xlnet import squad_utils
from
official.nlp.xlnet
import
training_utils
from
official.nlp.xlnet
import
xlnet_config
from
official.nlp.xlnet
import
xlnet_modeling
as
modeling
from
official.utils.misc
import
tpu_lib
from
official.utils.misc
import
distribution_utils
flags
.
DEFINE_string
(
"test_feature_path"
,
default
=
None
,
help
=
"Path to feature of test set."
)
...
...
@@ -217,14 +212,9 @@ def get_qaxlnet_model(model_config, run_config, start_n_top, end_n_top):
def
main
(
unused_argv
):
del
unused_argv
if
FLAGS
.
strategy_type
==
"mirror"
:
strategy
=
tf
.
distribute
.
MirroredStrategy
()
elif
FLAGS
.
strategy_type
==
"tpu"
:
cluster_resolver
=
tpu_lib
.
tpu_initialize
(
FLAGS
.
tpu
)
strategy
=
tf
.
distribute
.
experimental
.
TPUStrategy
(
cluster_resolver
)
else
:
raise
ValueError
(
"The distribution strategy type is not supported: %s"
%
FLAGS
.
strategy_type
)
strategy
=
distribution_utils
.
get_distribution_strategy
(
distribution_strategy
=
FLAGS
.
strategy_type
,
tpu_address
=
FLAGS
.
tpu
)
if
strategy
:
logging
.
info
(
"***** Number of cores used : %d"
,
strategy
.
num_replicas_in_sync
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment