Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
32e4ca51
Commit
32e4ca51
authored
Nov 28, 2023
by
qianyj
Browse files
Update code to v2.11.0
parents
9485aa1d
71060f67
Changes
775
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
48 additions
and
35 deletions
+48
-35
official/modeling/fast_training/progressive/policies.py
official/modeling/fast_training/progressive/policies.py
+1
-1
official/modeling/fast_training/progressive/train.py
official/modeling/fast_training/progressive/train.py
+1
-1
official/modeling/fast_training/progressive/train_lib.py
official/modeling/fast_training/progressive/train_lib.py
+1
-1
official/modeling/fast_training/progressive/train_lib_test.py
...cial/modeling/fast_training/progressive/train_lib_test.py
+1
-1
official/modeling/fast_training/progressive/trainer.py
official/modeling/fast_training/progressive/trainer.py
+1
-1
official/modeling/fast_training/progressive/trainer_test.py
official/modeling/fast_training/progressive/trainer_test.py
+7
-3
official/modeling/fast_training/progressive/utils.py
official/modeling/fast_training/progressive/utils.py
+3
-3
official/modeling/grad_utils.py
official/modeling/grad_utils.py
+1
-1
official/modeling/grad_utils_test.py
official/modeling/grad_utils_test.py
+1
-1
official/modeling/hyperparams/__init__.py
official/modeling/hyperparams/__init__.py
+1
-1
official/modeling/hyperparams/base_config.py
official/modeling/hyperparams/base_config.py
+1
-1
official/modeling/hyperparams/base_config_test.py
official/modeling/hyperparams/base_config_test.py
+1
-1
official/modeling/hyperparams/oneof.py
official/modeling/hyperparams/oneof.py
+1
-1
official/modeling/hyperparams/oneof_test.py
official/modeling/hyperparams/oneof_test.py
+1
-1
official/modeling/hyperparams/params_dict.py
official/modeling/hyperparams/params_dict.py
+12
-12
official/modeling/hyperparams/params_dict_test.py
official/modeling/hyperparams/params_dict_test.py
+1
-1
official/modeling/multitask/__init__.py
official/modeling/multitask/__init__.py
+1
-1
official/modeling/multitask/base_model.py
official/modeling/multitask/base_model.py
+10
-1
official/modeling/multitask/base_trainer.py
official/modeling/multitask/base_trainer.py
+1
-1
official/modeling/multitask/base_trainer_test.py
official/modeling/multitask/base_trainer_test.py
+1
-1
No files found.
Too many changes to show.
To preserve performance only
775 of 775+
files are displayed.
Plain diff
Email patch
official/modeling/fast_training/progressive/policies.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/fast_training/progressive/train.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/fast_training/progressive/train_lib.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/fast_training/progressive/train_lib_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/fast_training/progressive/trainer.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/fast_training/progressive/trainer_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -226,9 +226,13 @@ class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
task
=
TestPolicy
(
None
,
config
.
task
)
trainer
=
trainer_lib
.
ProgressiveTrainer
(
config
,
task
,
self
.
get_temp_dir
())
if
mixed_precision_dtype
!=
'float16'
:
self
.
assertIsInstance
(
trainer
.
optimizer
,
tf
.
keras
.
optimizers
.
SGD
)
self
.
assertIsInstance
(
trainer
.
optimizer
,
(
tf
.
keras
.
optimizers
.
SGD
,
tf
.
keras
.
optimizers
.
legacy
.
SGD
))
elif
mixed_precision_dtype
==
'float16'
and
loss_scale
is
None
:
self
.
assertIsInstance
(
trainer
.
optimizer
,
tf
.
keras
.
optimizers
.
SGD
)
self
.
assertIsInstance
(
trainer
.
optimizer
,
(
tf
.
keras
.
optimizers
.
SGD
,
tf
.
keras
.
optimizers
.
legacy
.
SGD
))
metrics
=
trainer
.
train
(
tf
.
convert_to_tensor
(
5
,
dtype
=
tf
.
int32
))
self
.
assertIn
(
'training_loss'
,
metrics
)
...
...
official/modeling/fast_training/progressive/utils.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -18,10 +18,10 @@ from absl import logging
import
tensorflow
as
tf
# pylint: disable=g-direct-tensorflow-import
from
tensorflow.python.tra
ining.tracking
import
track
ing
from
tensorflow.python.tra
ckable
import
auto
track
able
class
VolatileTrackable
(
track
ing
.
AutoTrackable
):
class
VolatileTrackable
(
auto
track
able
.
AutoTrackable
):
"""A util class to keep Trackables that might change instances."""
def
__init__
(
self
,
**
kwargs
):
...
...
official/modeling/grad_utils.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/grad_utils_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/hyperparams/__init__.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/hyperparams/base_config.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/hyperparams/base_config_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/hyperparams/oneof.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/hyperparams/oneof_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/hyperparams/params_dict.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -41,15 +41,15 @@ _PARAM_RE = re.compile(
_CONST_VALUE_RE
=
re
.
compile
(
r
'(\d.*|-\d.*|None)'
)
# Yaml
loader
with an implicit resolver to parse float decimal and exponential
# Yaml
LOADER
with an implicit resolver to parse float decimal and exponential
# format. The regular experission parse the following cases:
# 1- Decimal number with an optional exponential term.
# 2- Integer number with an exponential term.
# 3- Decimal number with an optional exponential term.
# 4- Decimal number.
LOADER
=
yaml
.
SafeLoader
LOADER
.
add_implicit_resolver
(
_
LOADER
=
yaml
.
SafeLoader
_
LOADER
.
add_implicit_resolver
(
'tag:yaml.org,2002:float'
,
re
.
compile
(
r
'''
^(?:[-+]?(?:[0-9][0-9_]*)\\.[0-9_]*(?:[eE][-+]?[0-9]+)?
...
...
@@ -288,42 +288,42 @@ class ParamsDict(object):
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
!=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
'Found inconsist
e
ncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'!='
in
restriction
:
tokens
=
restriction
.
split
(
'!='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
==
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
'Found inconsist
e
ncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<'
in
restriction
:
tokens
=
restriction
.
split
(
'<'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
'Found inconsist
e
ncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'<='
in
restriction
:
tokens
=
restriction
.
split
(
'<='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
>
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
'Found inconsist
e
ncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>'
in
restriction
:
tokens
=
restriction
.
split
(
'>'
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<=
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
'Found inconsist
e
ncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
elif
'>='
in
restriction
:
tokens
=
restriction
.
split
(
'>='
)
_
,
left_v
,
_
,
right_v
=
_get_kvs
(
tokens
,
params_dict
)
if
left_v
<
right_v
:
raise
KeyError
(
'Found inconsistncy between key `{}` and key `{}`.'
.
format
(
'Found inconsist
e
ncy between key `{}` and key `{}`.'
.
format
(
tokens
[
0
],
tokens
[
1
]))
else
:
raise
ValueError
(
'Unsupported relation in restriction.'
)
...
...
@@ -332,7 +332,7 @@ class ParamsDict(object):
def
read_yaml_to_params_dict
(
file_path
:
str
):
"""Reads a YAML file to a ParamsDict."""
with
tf
.
io
.
gfile
.
GFile
(
file_path
,
'r'
)
as
f
:
params_dict
=
yaml
.
load
(
f
,
Loader
=
LOADER
)
params_dict
=
yaml
.
load
(
f
,
Loader
=
_
LOADER
)
return
ParamsDict
(
params_dict
)
...
...
@@ -453,7 +453,7 @@ def override_params_dict(params, dict_or_string_or_yaml_file, is_strict):
nested_csv_str_to_json_str
(
dict_or_string_or_yaml_file
))
except
ValueError
:
pass
params_dict
=
yaml
.
load
(
dict_or_string_or_yaml_file
,
Loader
=
LOADER
)
params_dict
=
yaml
.
load
(
dict_or_string_or_yaml_file
,
Loader
=
_
LOADER
)
if
isinstance
(
params_dict
,
dict
):
params
.
override
(
params_dict
,
is_strict
)
else
:
...
...
official/modeling/hyperparams/params_dict_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/multitask/__init__.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/multitask/base_model.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
@@ -43,3 +43,12 @@ class MultiTaskBaseModel(tf.Module):
def
initialize
(
self
):
"""Optional function that loads a pre-train checkpoint."""
return
def
build
(
self
):
"""Builds the networks for tasks to make sure variables are created."""
# Try to build all sub tasks.
for
task_model
in
self
.
_sub_tasks
.
values
():
# Assumes all the tf.Module models are built because we don't have any
# way to check them.
if
isinstance
(
task_model
,
tf
.
keras
.
Model
)
and
not
task_model
.
built
:
_
=
task_model
(
task_model
.
inputs
)
official/modeling/multitask/base_trainer.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
official/modeling/multitask/base_trainer_test.py
View file @
32e4ca51
# Copyright 202
1
The TensorFlow Authors. All Rights Reserved.
# Copyright 202
2
The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
...
...
Prev
1
…
9
10
11
12
13
14
15
16
17
…
39
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment