Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
0bd9a8b2
Commit
0bd9a8b2
authored
Aug 03, 2021
by
Le Hou
Committed by
A. Unique TensorFlower
Aug 03, 2021
Browse files
Internal change
PiperOrigin-RevId: 388575593
parent
cc30e3e9
Changes
12
Hide whitespace changes
Inline
Side-by-side
Showing
12 changed files
with
305 additions
and
18 deletions
+305
-18
official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
.../modeling/fast_training/experimental/tf2_utils_2x_wide.py
+186
-0
official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py
...ling/fast_training/experimental/tf2_utils_2x_wide_test.py
+101
-0
official/modeling/fast_training/progressive/policies.py
official/modeling/fast_training/progressive/policies.py
+2
-2
official/modeling/fast_training/progressive/train.py
official/modeling/fast_training/progressive/train.py
+1
-1
official/modeling/fast_training/progressive/train_lib.py
official/modeling/fast_training/progressive/train_lib.py
+1
-1
official/modeling/fast_training/progressive/train_lib_test.py
...cial/modeling/fast_training/progressive/train_lib_test.py
+3
-3
official/modeling/fast_training/progressive/trainer.py
official/modeling/fast_training/progressive/trainer.py
+4
-4
official/modeling/fast_training/progressive/trainer_test.py
official/modeling/fast_training/progressive/trainer_test.py
+2
-2
official/modeling/fast_training/progressive/utils.py
official/modeling/fast_training/progressive/utils.py
+0
-0
official/nlp/projects/mobilebert/distillation.py
official/nlp/projects/mobilebert/distillation.py
+2
-2
official/nlp/projects/mobilebert/distillation_test.py
official/nlp/projects/mobilebert/distillation_test.py
+1
-1
official/nlp/projects/mobilebert/run_distillation.py
official/nlp/projects/mobilebert/run_distillation.py
+2
-2
No files found.
official/modeling/fast_training/experimental/tf2_utils_2x_wide.py
0 → 100644
View file @
0bd9a8b2
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Stacking model horizontally."""
from
absl
import
logging
import
numpy
as
np
import
tensorflow
as
tf
def
expand_vector
(
v
:
np
.
ndarray
)
->
np
.
ndarray
:
"""Expands a vector with batch dimensions.
Equivalent to expand_1_axis(v, epsilon=0.0, axis=-1)
Args:
v: A vector with shape [..., a].
Returns:
A vector with shape [..., 2 * a].
"""
return
np
.
repeat
(
v
,
2
,
axis
=-
1
)
def
expand_1_axis
(
w
:
np
.
ndarray
,
epsilon
:
float
,
axis
:
int
)
->
np
.
ndarray
:
"""Expands either the first dimension or the last dimension of w.
If `axis = 0`, the following constraint will be satisfied:
matmul(x, w) ==
matmul(expand_vector(x), expand_1_axis(w, epsilon=0.1, axis=0))
If `axis = -1`, the following constraint will be satisfied if `epsilon = 0.0`:
expand_vector(matmul(x, w)) ==
2 * matmul(x, expand_1_axis(w, epsilon=0.0, axis=-1))
Args:
w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
epsilon: Symmetric Noise added to expanded tensor.
axis: Must be either 0 or -1.
Returns:
Expanded numpy array.
"""
assert
axis
in
(
0
,
-
1
),
(
"Only support expanding the first or the last dimension. "
"Got: {}"
.
format
(
axis
))
rank
=
len
(
w
.
shape
)
d_w
=
np
.
random
.
normal
(
np
.
zeros_like
(
w
),
np
.
fabs
(
w
)
*
epsilon
,
w
.
shape
)
d_w
=
np
.
repeat
(
d_w
,
2
,
axis
=
axis
)
sign_flip
=
np
.
array
([
1
,
-
1
])
for
_
in
range
(
rank
-
1
):
sign_flip
=
np
.
expand_dims
(
sign_flip
,
axis
=-
1
if
axis
==
0
else
0
)
sign_flip
=
np
.
tile
(
sign_flip
,
[
w
.
shape
[
0
]]
+
[
1
]
*
(
rank
-
2
)
+
[
w
.
shape
[
-
1
]])
d_w
*=
sign_flip
w_expand
=
(
np
.
repeat
(
w
,
2
,
axis
=
axis
)
+
d_w
)
/
2
return
w_expand
def
expand_2_axes
(
w
:
np
.
ndarray
,
epsilon
:
float
)
->
np
.
ndarray
:
"""Expands the first dimension and the last dimension of w.
The following constraint will be satisfied:
expand_vector(matmul(x, w)) == matmul(expand_vector(x), expand_2_axes(w))
Args:
w: Numpy array of shape [a_0, a_1, ..., a_i-1, a_i].
epsilon: Symmetric Noise added to expanded tensor.
Returns:
Expanded numpy array.
"""
rank
=
len
(
w
.
shape
)
d_w
=
np
.
random
.
normal
(
np
.
zeros_like
(
w
),
np
.
fabs
(
w
)
*
epsilon
,
w
.
shape
)
d_w
=
np
.
repeat
(
np
.
repeat
(
d_w
,
2
,
axis
=
0
),
2
,
axis
=-
1
)
sign_flip
=
np
.
array
([
1
,
-
1
])
for
_
in
range
(
rank
-
1
):
sign_flip
=
np
.
expand_dims
(
sign_flip
,
axis
=-
1
)
sign_flip
=
np
.
tile
(
sign_flip
,
[
w
.
shape
[
0
]]
+
[
1
]
*
(
rank
-
2
)
+
[
w
.
shape
[
-
1
]
*
2
])
d_w
*=
sign_flip
w_expand
=
(
np
.
repeat
(
np
.
repeat
(
w
,
2
,
axis
=
0
),
2
,
axis
=-
1
)
+
d_w
)
/
2
return
w_expand
def
var_to_var
(
var_from
:
tf
.
Variable
,
var_to
:
tf
.
Variable
,
epsilon
:
float
):
"""Expands a variable to another variable.
Assume the shape of `var_from` is (a, b, ..., y, z), the shape of `var_to`
can be (a, ..., z * 2), (a * 2, ..., z * 2), (a * 2, ..., z)
If the shape of `var_to` is (a, ..., 2 * z):
For any x, tf.matmul(x, var_to) ~= expand_vector(tf.matmul(x, var_from)) / 2
Not that there will be noise added to the left hand side, if epsilon != 0.
If the shape of `var_to` is (2 * a, ..., z):
For any x, tf.matmul(expand_vector(x), var_to) == tf.matmul(x, var_from)
If the shape of `var_to` is (2 * a, ..., 2 * z):
For any x, tf.matmul(expand_vector(x), var_to) ==
expand_vector(tf.matmul(expand_vector(x), var_from))
Args:
var_from: input variable to expand.
var_to: output variable.
epsilon: the noise ratio that will be added, when splitting `var_from`.
"""
shape_from
=
var_from
.
shape
shape_to
=
var_to
.
shape
if
shape_from
==
shape_to
:
var_to
.
assign
(
var_from
)
elif
len
(
shape_from
)
==
1
and
len
(
shape_to
)
==
1
:
var_to
.
assign
(
expand_vector
(
var_from
.
numpy
()))
elif
shape_from
[
0
]
*
2
==
shape_to
[
0
]
and
shape_from
[
-
1
]
==
shape_to
[
-
1
]:
var_to
.
assign
(
expand_1_axis
(
var_from
.
numpy
(),
epsilon
=
epsilon
,
axis
=
0
))
elif
shape_from
[
0
]
==
shape_to
[
0
]
and
shape_from
[
-
1
]
*
2
==
shape_to
[
-
1
]:
var_to
.
assign
(
expand_1_axis
(
var_from
.
numpy
(),
epsilon
=
epsilon
,
axis
=-
1
))
elif
shape_from
[
0
]
*
2
==
shape_to
[
0
]
and
shape_from
[
-
1
]
*
2
==
shape_to
[
-
1
]:
var_to
.
assign
(
expand_2_axes
(
var_from
.
numpy
(),
epsilon
=
epsilon
))
else
:
raise
ValueError
(
"Shape not supported, {}, {}"
.
format
(
shape_from
,
shape_to
))
def
model_to_model_2x_wide
(
model_from
:
tf
.
Module
,
model_to
:
tf
.
Module
,
epsilon
:
float
=
0.1
):
"""Expands a model to a wider version.
Also makes sure that the output of the model is not changed after expanding.
For example:
```
model_narrow = tf.keras.Sequential()
model_narrow.add(tf.keras.Input(shape=(3,)))
model_narrow.add(tf.keras.layers.Dense(4))
model_narrow.add(tf.keras.layers.Dense(1))
model_wide = tf.keras.Sequential()
model_wide.add(tf.keras.Input(shape=(6,)))
model_wide.add(tf.keras.layers.Dense(8))
model_wide.add(tf.keras.layers.Dense(1))
model_to_model_2x_wide(model_narrow, model_wide)
assert model_narrow([[1, 2, 3]]) == model_wide([[1, 1, 2, 2, 3, 3]])
```
We assume that `model_from` and `model_to` has the same architecture and only
widths of them differ.
Args:
model_from: input model to expand.
model_to: output model whose variables will be assigned expanded values
according to `model_from`.
epsilon: the noise ratio that will be added, when splitting `var_from`.
"""
for
w_from
,
w_to
in
zip
(
model_from
.
trainable_variables
,
model_to
.
trainable_variables
):
logging
.
info
(
"expanding %s %s to %s %s"
,
w_from
.
name
,
w_from
.
shape
,
w_to
.
name
,
w_to
.
shape
)
var_to_var
(
w_from
,
w_to
,
epsilon
=
epsilon
)
official/modeling/fast_training/experimental/tf2_utils_2x_wide_test.py
0 → 100644
View file @
0bd9a8b2
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tf2_utils_2x_wide."""
import
numpy
as
np
import
tensorflow
as
tf
from
official.modeling.fast_training.experimental
import
tf2_utils_2x_wide
class
Tf2Utils2XWideTest
(
tf
.
test
.
TestCase
):
def
test_expand_vector
(
self
):
x
=
np
.
array
([
1
,
2
])
self
.
assertAllClose
(
tf2_utils_2x_wide
.
expand_vector
(
x
),
np
.
array
([
1
,
1
,
2
,
2
]))
def
test_expand_matrix
(
self
):
x
=
np
.
array
([[
1
,
2
],
[
3
,
4
]])
x
=
tf2_utils_2x_wide
.
expand_2_axes
(
x
,
epsilon
=
0.1
)
self
.
assertAllClose
(
x
[
0
,
:]
+
x
[
1
,
:],
np
.
array
([
1
,
1
,
2
,
2
]))
self
.
assertAllClose
(
x
[
2
,
:]
+
x
[
3
,
:],
np
.
array
([
3
,
3
,
4
,
4
]))
def
test_expand_matrix_axis_0
(
self
):
x
=
np
.
array
([[
1
,
2
],
[
3
,
4
]])
x
=
tf2_utils_2x_wide
.
expand_1_axis
(
x
,
axis
=
0
,
epsilon
=
0.1
)
self
.
assertAllClose
(
x
[
0
,
:]
+
x
[
1
,
:],
np
.
array
([
1
,
2
]))
self
.
assertAllClose
(
x
[
2
,
:]
+
x
[
3
,
:],
np
.
array
([
3
,
4
]))
def
test_expand_matrix_axis_1
(
self
):
x
=
np
.
array
([[
1
,
2
],
[
3
,
4
]])
x
=
tf2_utils_2x_wide
.
expand_1_axis
(
x
,
axis
=-
1
,
epsilon
=
0.1
)
self
.
assertAllClose
(
x
[:,
0
]
+
x
[:,
1
],
np
.
array
([
1
,
3
]))
self
.
assertAllClose
(
x
[:,
2
]
+
x
[:,
3
],
np
.
array
([
2
,
4
]))
def
test_expand_3d_tensor
(
self
):
x0
=
np
.
array
([
10
,
11
])
x1
=
np
.
array
([
10
,
10
,
11
,
11
])
w0
=
np
.
random
.
rand
(
2
,
2
)
w1
=
tf2_utils_2x_wide
.
expand_2_axes
(
w0
,
epsilon
=
0.1
)
o0
=
np
.
matmul
(
x0
,
w0
)
o1
=
np
.
matmul
(
x1
,
w1
)
self
.
assertAllClose
(
np
.
repeat
(
o0
,
2
,
axis
=-
1
),
o1
)
def
test_expand_3d_tensor_axis_0
(
self
):
x0
=
np
.
array
([
10
,
11
])
x1
=
np
.
array
([
10
,
10
,
11
,
11
])
w0
=
np
.
random
.
rand
(
2
,
2
)
w1
=
tf2_utils_2x_wide
.
expand_1_axis
(
w0
,
axis
=
0
,
epsilon
=
0.1
)
o0
=
np
.
matmul
(
x0
,
w0
)
o1
=
np
.
matmul
(
x1
,
w1
)
self
.
assertAllClose
(
o0
,
o1
)
def
test_expand_3d_tensor_axis_2
(
self
):
x
=
np
.
array
([
10
,
11
])
w0
=
np
.
random
.
rand
(
2
,
2
)
w1
=
tf2_utils_2x_wide
.
expand_1_axis
(
w0
,
axis
=-
1
,
epsilon
=
0.1
)
o0
=
np
.
matmul
(
x
,
w0
)
o1
=
np
.
matmul
(
x
,
w1
)
self
.
assertAllClose
(
o0
,
np
.
sum
(
o1
.
reshape
(
2
,
2
),
axis
=-
1
))
def
test_end_to_end
(
self
):
"""Covers expand_vector, expand_2_axes, and expand_1_axis."""
model_narrow
=
tf
.
keras
.
Sequential
()
model_narrow
.
add
(
tf
.
keras
.
Input
(
shape
=
(
3
,)))
model_narrow
.
add
(
tf
.
keras
.
layers
.
Dense
(
4
))
model_narrow
.
add
(
tf
.
keras
.
layers
.
Dense
(
4
))
model_narrow
.
add
(
tf
.
keras
.
layers
.
Dense
(
1
))
model_wide
=
tf
.
keras
.
Sequential
()
model_wide
.
add
(
tf
.
keras
.
Input
(
shape
=
(
6
,)))
model_wide
.
add
(
tf
.
keras
.
layers
.
Dense
(
8
))
model_wide
.
add
(
tf
.
keras
.
layers
.
Dense
(
8
))
model_wide
.
add
(
tf
.
keras
.
layers
.
Dense
(
1
))
x0
=
np
.
array
([[
1
,
2
,
3
]])
x1
=
np
.
array
([[
1
,
1
,
2
,
2
,
3
,
3
]])
# Call model once to build variables first.
_
,
_
=
model_narrow
(
x0
),
model_wide
(
x1
)
tf2_utils_2x_wide
.
model_to_model_2x_wide
(
model_narrow
,
model_wide
,
epsilon
=
0.2
)
self
.
assertAllClose
(
model_narrow
(
x0
),
model_wide
(
x1
),
rtol
=
1e-05
,
atol
=
1e-05
)
if
__name__
==
"__main__"
:
tf
.
test
.
main
()
official/modeling/progressive/policies.py
→
official/modeling/
fast_training/
progressive/policies.py
View file @
0bd9a8b2
...
@@ -19,13 +19,13 @@ abstract methods to handle each training stage.
...
@@ -19,13 +19,13 @@ abstract methods to handle each training stage.
"""
"""
import
abc
import
abc
import
dataclasses
from
typing
import
Any
,
Mapping
from
typing
import
Any
,
Mapping
from
absl
import
logging
from
absl
import
logging
import
dataclasses
import
six
import
six
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.modeling.fast_training.progressive
import
utils
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.progressive
import
utils
@
dataclasses
.
dataclass
@
dataclasses
.
dataclass
...
...
official/modeling/progressive/train.py
→
official/modeling/
fast_training/
progressive/train.py
View file @
0bd9a8b2
...
@@ -26,7 +26,7 @@ from official.common import flags as tfm_flags
...
@@ -26,7 +26,7 @@ from official.common import flags as tfm_flags
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.core
import
train_utils
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.modeling.progressive
import
train_lib
from
official.modeling.
fast_training.
progressive
import
train_lib
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
...
official/modeling/progressive/train_lib.py
→
official/modeling/
fast_training/
progressive/train_lib.py
View file @
0bd9a8b2
...
@@ -29,7 +29,7 @@ import tensorflow as tf
...
@@ -29,7 +29,7 @@ import tensorflow as tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
config_definitions
from
official.core
import
config_definitions
from
official.core
import
train_lib
as
base_train_lib
from
official.core
import
train_lib
as
base_train_lib
from
official.modeling.progressive
import
trainer
as
prog_trainer_lib
from
official.modeling.
fast_training.
progressive
import
trainer
as
prog_trainer_lib
def
run_experiment
(
distribution_strategy
:
tf
.
distribute
.
Strategy
,
def
run_experiment
(
distribution_strategy
:
tf
.
distribute
.
Strategy
,
...
...
official/modeling/progressive/train_lib_test.py
→
official/modeling/
fast_training/
progressive/train_lib_test.py
View file @
0bd9a8b2
...
@@ -31,9 +31,9 @@ from official.core import config_definitions as cfg
...
@@ -31,9 +31,9 @@ from official.core import config_definitions as cfg
from
official.core
import
task_factory
from
official.core
import
task_factory
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.hyperparams
import
params_dict
from
official.modeling.progressive
import
policies
from
official.modeling.
fast_training.
progressive
import
policies
from
official.modeling.progressive
import
train_lib
from
official.modeling.
fast_training.
progressive
import
train_lib
from
official.modeling.progressive
import
trainer
as
prog_trainer_lib
from
official.modeling.
fast_training.
progressive
import
trainer
as
prog_trainer_lib
from
official.utils.testing
import
mock_task
from
official.utils.testing
import
mock_task
FLAGS
=
flags
.
FLAGS
FLAGS
=
flags
.
FLAGS
...
...
official/modeling/progressive/trainer.py
→
official/modeling/
fast_training/
progressive/trainer.py
View file @
0bd9a8b2
...
@@ -18,21 +18,21 @@ The trainer implements the Orbit `StandardTrainable` and
...
@@ -18,21 +18,21 @@ The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangable and independent on model architectures and tasks.
interchangable and independent on model architectures and tasks.
"""
"""
import
dataclasses
import
os
import
os
from
typing
import
Any
,
Optional
from
typing
import
Any
,
Optional
# Import libraries
# Import libraries
from
absl
import
logging
from
absl
import
logging
import
dataclasses
import
gin
import
gin
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
base_trainer
as
trainer_lib
from
official.core
import
base_trainer
as
trainer_lib
from
official.core
import
config_definitions
from
official.core
import
config_definitions
from
official.modeling.progressive
import
policies
from
official.modeling.
fast_training.
progressive
import
policies
from
official.modeling.progressive
import
utils
from
official.modeling.
fast_training.
progressive
import
utils
ExperimentConfig
=
config_definitions
.
ExperimentConfig
ExperimentConfig
=
config_definitions
.
ExperimentConfig
...
...
official/modeling/progressive/trainer_test.py
→
official/modeling/
fast_training/
progressive/trainer_test.py
View file @
0bd9a8b2
...
@@ -24,8 +24,8 @@ from tensorflow.python.distribute import combinations
...
@@ -24,8 +24,8 @@ from tensorflow.python.distribute import combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
tensorflow.python.distribute
import
strategy_combinations
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling.progressive
import
policies
from
official.modeling.
fast_training.
progressive
import
policies
from
official.modeling.progressive
import
trainer
as
trainer_lib
from
official.modeling.
fast_training.
progressive
import
trainer
as
trainer_lib
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.utils.testing
import
mock_task
from
official.utils.testing
import
mock_task
...
...
official/modeling/progressive/utils.py
→
official/modeling/
fast_training/
progressive/utils.py
View file @
0bd9a8b2
File moved
official/nlp/projects/mobilebert/distillation.py
View file @
0bd9a8b2
...
@@ -13,18 +13,18 @@
...
@@ -13,18 +13,18 @@
# limitations under the License.
# limitations under the License.
"""Progressive distillation for MobileBERT student model."""
"""Progressive distillation for MobileBERT student model."""
import
dataclasses
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
from
absl
import
logging
from
absl
import
logging
import
dataclasses
import
orbit
import
orbit
import
tensorflow
as
tf
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.modeling.fast_training.progressive
import
policies
from
official.modeling.hyperparams
import
base_config
from
official.modeling.hyperparams
import
base_config
from
official.modeling.progressive
import
policies
from
official.nlp
import
keras_nlp
from
official.nlp
import
keras_nlp
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
...
...
official/nlp/projects/mobilebert/distillation_test.py
View file @
0bd9a8b2
...
@@ -22,7 +22,7 @@ import tensorflow as tf
...
@@ -22,7 +22,7 @@ import tensorflow as tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
config_definitions
as
cfg
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling
import
tf_utils
from
official.modeling
import
tf_utils
from
official.modeling.progressive
import
trainer
as
prog_trainer_lib
from
official.modeling.
fast_training.
progressive
import
trainer
as
prog_trainer_lib
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
bert
from
official.nlp.configs
import
encoders
from
official.nlp.configs
import
encoders
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.data
import
pretrain_dataloader
...
...
official/nlp/projects/mobilebert/run_distillation.py
View file @
0bd9a8b2
...
@@ -28,8 +28,8 @@ from official.core import train_utils
...
@@ -28,8 +28,8 @@ from official.core import train_utils
from
official.modeling
import
hyperparams
from
official.modeling
import
hyperparams
from
official.modeling
import
optimization
from
official.modeling
import
optimization
from
official.modeling
import
performance
from
official.modeling
import
performance
from
official.modeling.progressive
import
train_lib
from
official.modeling.
fast_training.
progressive
import
train_lib
from
official.modeling.progressive
import
trainer
as
prog_trainer_lib
from
official.modeling.
fast_training.
progressive
import
trainer
as
prog_trainer_lib
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.data
import
pretrain_dataloader
from
official.nlp.projects.mobilebert
import
distillation
from
official.nlp.projects.mobilebert
import
distillation
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment