Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
40926c83
Commit
40926c83
authored
Oct 18, 2022
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 482019933
parent
3d10262e
Changes
2
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
1016 additions
and
0 deletions
+1016
-0
official/nlp/modeling/layers/moe.py
official/nlp/modeling/layers/moe.py
+761
-0
official/nlp/modeling/layers/moe_test.py
official/nlp/modeling/layers/moe_test.py
+255
-0
No files found.
official/nlp/modeling/layers/moe.py
0 → 100644
View file @
40926c83
This diff is collapsed.
Click to expand it.
official/nlp/modeling/layers/moe_test.py
0 → 100644
View file @
40926c83
# Copyright 2022 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for moe.py."""
import
ml_collections
import
numpy
as
np
import
tensorflow
as
tf
from
official.nlp.modeling.layers
import
moe
def
small_config
()
->
ml_collections
.
ConfigDict
:
"""Creates a small model config that can be used by all tests."""
config
=
ml_collections
.
ConfigDict
()
config
.
d_ff
=
32
config
.
dropout_rate
=
0.1
config
.
num_experts
=
2
config
.
expert_d_ff
=
33
config
.
expert_dropout_rate
=
0.1
config
.
jitter_noise
=
0.1
config
.
train_capacity_factor
=
1.0
config
.
eval_capacity_factor
=
1.0
config
.
min_expert_capacity
=
1
config
.
max_group_size
=
9
config
.
backbone_d_ff
=
13
return
config
def
make_input_ones
(
batch_size
:
int
=
2
,
seq_length
:
int
=
10
,
hidden_dim
:
int
=
7
)
->
tf
.
Tensor
:
return
tf
.
ones
((
batch_size
,
seq_length
,
hidden_dim
),
dtype
=
tf
.
float32
)
def
make_experts_input_ones
(
num_groups
:
int
=
1
,
num_experts
:
int
=
2
,
expert_capacity
:
int
=
5
,
hidden_dim
:
int
=
7
)
->
tf
.
Tensor
:
return
tf
.
ones
((
num_groups
,
num_experts
,
expert_capacity
,
hidden_dim
),
dtype
=
tf
.
float32
)
class
MoeTest
(
tf
.
test
.
TestCase
):
def
tearDown
(
self
):
super
().
tearDown
()
tf
.
keras
.
mixed_precision
.
set_global_policy
(
'float32'
)
def
test_router_z_loss_dtype
(
self
):
x
=
tf
.
constant
([[[
10.0
,
5.0
]]],
dtype
=
tf
.
float32
)
y
=
moe
.
_router_z_loss
(
x
)
expected
=
(
5
+
np
.
log
(
np
.
exp
(
5
)
+
1
))
**
2
self
.
assertAllClose
(
expected
,
y
,
atol
=
1e-7
)
x
=
tf
.
constant
([[[
10.0
,
5.0
]]],
dtype
=
tf
.
bfloat16
)
y
=
moe
.
_router_z_loss
(
x
)
expected
=
100.0
self
.
assertAllClose
(
expected
,
y
,
atol
=
1e-7
)
def
test_router_z_loss_shape
(
self
):
x
=
make_input_ones
(
2
,
5
,
7
)
y
=
moe
.
_router_z_loss
(
x
)
expected
=
(
np
.
log
(
7
)
+
1
)
**
2
self
.
assertAllClose
(
expected
,
y
,
atol
=
1e-7
)
def
test_experts_choose_masked_router_dtype_shape
(
self
):
tf
.
keras
.
mixed_precision
.
set_global_policy
(
'mixed_bfloat16'
)
num_groups
=
2
tokens_per_group
=
3
hidden_dim
=
tokens_per_group
num_experts
=
tokens_per_group
expert_capacity
=
2
x
=
np
.
zeros
([
num_groups
,
tokens_per_group
,
hidden_dim
])
x
[
0
,
0
,
0
]
+=
1
x
[
0
,
:
2
,
:
2
]
+=
1
x
[
1
,
1
:,
1
:]
+=
1
x
[
1
,
-
1
,
-
1
]
+=
1
router
=
moe
.
ExpertsChooseMaskedRouter
(
num_experts
=
num_experts
,
jitter_noise
=
0.1
,
use_bias
=
True
,
kernel_initializer
=
tf
.
keras
.
initializers
.
get
(
'identity'
),
bias_initializer
=
tf
.
keras
.
initializers
.
get
(
'ones'
))
router_mask
=
router
(
x
,
expert_capacity
=
expert_capacity
,
training
=
False
)
self
.
assertDTypeEqual
(
router_mask
.
dispatch_mask
,
tf
.
bfloat16
)
self
.
assertDTypeEqual
(
router_mask
.
combine_array
,
tf
.
bfloat16
)
expect_shape
=
[
num_groups
,
tokens_per_group
,
num_experts
,
expert_capacity
]
self
.
assertEqual
(
expect_shape
,
router_mask
.
dispatch_mask
.
shape
)
self
.
assertEqual
(
expect_shape
,
router_mask
.
combine_array
.
shape
)
# top_k call may not be sorted, so can't compare the output directly
# Check that the output contains only 0s and 1s
out_dm
=
router_mask
.
dispatch_mask
.
numpy
()
self
.
assertSetEqual
({
0
,
1
},
set
(
out_dm
.
flatten
().
astype
(
np
.
int32
)))
# Check that the right tokens for selected
out_dm_indices
=
np
.
dot
(
out_dm
.
transpose
((
0
,
2
,
3
,
1
)),
np
.
arange
(
tokens_per_group
))
# Shape [num_groups, num_experts, expert_capacity]
self
.
assertSetEqual
({
0
,
1
},
set
(
out_dm_indices
[
0
,
0
,
:].
astype
(
np
.
int32
)))
self
.
assertSetEqual
({
1
,
2
},
set
(
out_dm_indices
[
0
,
1
,
:].
astype
(
np
.
int32
)))
self
.
assertSetEqual
({
1
,
2
},
set
(
out_dm_indices
[
0
,
2
,
:].
astype
(
np
.
int32
)))
self
.
assertSetEqual
({
0
,
1
},
set
(
out_dm_indices
[
1
,
0
,
:].
astype
(
np
.
int32
)))
self
.
assertSetEqual
({
0
,
1
},
set
(
out_dm_indices
[
1
,
1
,
:].
astype
(
np
.
int32
)))
self
.
assertSetEqual
({
1
,
2
},
set
(
out_dm_indices
[
1
,
2
,
:].
astype
(
np
.
int32
)))
out_ca
=
router_mask
.
combine_array
.
numpy
()
out_ca
=
np
.
dot
(
out_ca
,
np
.
ones
((
expert_capacity
,)))
expected_combine_array
=
np
.
array
(
[[[
0.66
,
0.0
,
0.0
],
[
0.42
,
0.42
,
0.16
],
[
0.0
,
0.33
,
0.33
]],
[[
0.33
,
0.33
,
0.0
],
[
0.16
,
0.42
,
0.42
],
[
0.0
,
0.0
,
0.66
]]])
self
.
assertAllClose
(
expected_combine_array
,
out_ca
,
atol
=
1e-2
)
def
test_feed_forward_shape_and_vars
(
self
):
config
=
small_config
()
layer
=
moe
.
FeedForward
(
d_ff
=
config
.
d_ff
,
dropout_rate
=
config
.
dropout_rate
)
inputs
=
make_input_ones
()
outputs
=
layer
(
inputs
)
self
.
assertAllEqual
(
tf
.
shape
(
inputs
),
tf
.
shape
(
outputs
))
var_names
=
sorted
([
v
.
name
for
v
in
layer
.
trainable_variables
])
self
.
assertAllEqual
([
'feed_forward/intermediate/bias:0'
,
'feed_forward/intermediate/kernel:0'
,
'feed_forward/output/bias:0'
,
'feed_forward/output/kernel:0'
],
var_names
)
def
test_feed_forward_manual
(
self
):
config
=
small_config
()
layer
=
moe
.
FeedForward
(
d_ff
=
config
.
d_ff
,
dropout_rate
=
config
.
dropout_rate
,
activation
=
tf
.
keras
.
activations
.
relu
,
kernel_initializer
=
tf
.
keras
.
initializers
.
get
(
'ones'
),
bias_initializer
=
tf
.
keras
.
initializers
.
get
(
'ones'
))
inputs
=
make_input_ones
(
1
,
2
,
3
)
outputs
=
layer
(
inputs
,
training
=
False
)
manual_outputs
=
tf
.
constant
([[[
129.0
,
129.0
,
129.0
],
[
129.0
,
129.0
,
129.0
]]])
self
.
assertAllClose
(
manual_outputs
,
outputs
,
atol
=
1e-7
)
def
test_feed_forward_experts_shape_and_vars
(
self
):
config
=
small_config
()
layer
=
moe
.
FeedForwardExperts
(
num_experts
=
config
.
num_experts
,
d_ff
=
config
.
expert_d_ff
,
dropout_rate
=
config
.
expert_dropout_rate
)
inputs
=
make_experts_input_ones
()
outputs
=
layer
(
inputs
)
self
.
assertAllEqual
(
tf
.
shape
(
inputs
),
tf
.
shape
(
outputs
))
var_names
=
sorted
([
v
.
name
for
v
in
layer
.
trainable_variables
])
self
.
assertAllEqual
([
'experts/intermediate/bias:0'
,
'experts/intermediate/kernel:0'
,
'experts/output/bias:0'
,
'experts/output/kernel:0'
],
var_names
)
def
test_feed_forward_experts_manual
(
self
):
config
=
small_config
()
layer
=
moe
.
FeedForwardExperts
(
num_experts
=
1
,
d_ff
=
config
.
expert_d_ff
,
dropout_rate
=
config
.
expert_dropout_rate
,
activation
=
tf
.
keras
.
activations
.
relu
,
kernel_initializer
=
tf
.
keras
.
initializers
.
get
(
'ones'
),
bias_initializer
=
tf
.
keras
.
initializers
.
get
(
'ones'
))
inputs
=
make_experts_input_ones
(
1
,
1
,
2
,
3
)
outputs
=
layer
(
inputs
,
training
=
False
)
manual_outputs
=
tf
.
constant
([[[[
133.0
,
133.0
,
133.0
],
[
133.0
,
133.0
,
133.0
]]]])
self
.
assertAllClose
(
manual_outputs
,
outputs
,
atol
=
1e-7
)
def
test_moe_layer
(
self
):
config
=
small_config
()
experts
=
moe
.
FeedForwardExperts
(
num_experts
=
config
.
num_experts
,
d_ff
=
config
.
expert_d_ff
,
dropout_rate
=
config
.
expert_dropout_rate
)
router
=
moe
.
ExpertsChooseMaskedRouter
(
config
.
num_experts
,
jitter_noise
=
config
.
jitter_noise
)
moe_layer
=
moe
.
MoeLayer
(
experts
,
router
,
train_capacity_factor
=
config
.
train_capacity_factor
,
eval_capacity_factor
=
config
.
eval_capacity_factor
,
max_group_size
=
config
.
max_group_size
,
min_expert_capacity
=
config
.
min_expert_capacity
)
inputs
=
make_input_ones
()
with
self
.
assertLogs
(
'absl'
,
level
=
'INFO'
)
as
cm
:
outputs
=
moe_layer
(
inputs
,
training
=
True
)
self
.
assertAllEqual
(
tf
.
shape
(
inputs
),
tf
.
shape
(
outputs
))
self
.
assertEqual
(
cm
.
output
,
[
(
'INFO:absl:Selected group_size=5 and num_groups=4 for input '
'num_tokens=20, max_group_size=9, num_experts=2.'
),
(
'INFO:absl:Selected expert_capacity=2 for num_experts=2 and '
'training=True.'
)])
var_names
=
sorted
([
v
.
name
for
v
in
moe_layer
.
trainable_variables
])
self
.
assertAllEqual
([
'moe/experts/intermediate/bias:0'
,
'moe/experts/intermediate/kernel:0'
,
'moe/experts/output/bias:0'
,
'moe/experts/output/kernel:0'
,
'moe/router/router_weights/bias:0'
,
'moe/router/router_weights/kernel:0'
],
var_names
)
self
.
assertLen
(
moe_layer
.
losses
,
1
)
metrics
=
[
metric
.
name
for
metric
in
moe_layer
.
metrics
]
self
.
assertSetEqual
(
{
'router_z_loss'
,
'load_balancing_loss'
,
'fraction_tokens_left_behind'
,
'router_confidence'
,
'expert_usage'
},
set
(
metrics
))
def
test_moe_layer_with_backbone
(
self
):
config
=
small_config
()
experts
=
moe
.
FeedForwardExperts
(
num_experts
=
config
.
num_experts
,
d_ff
=
config
.
expert_d_ff
,
dropout_rate
=
config
.
expert_dropout_rate
)
router
=
moe
.
ExpertsChooseMaskedRouter
(
config
.
num_experts
,
jitter_noise
=
config
.
jitter_noise
)
moe_layer
=
moe
.
MoeLayer
(
experts
,
router
,
train_capacity_factor
=
config
.
train_capacity_factor
,
eval_capacity_factor
=
config
.
eval_capacity_factor
,
max_group_size
=
config
.
max_group_size
,
min_expert_capacity
=
config
.
min_expert_capacity
)
layer
=
moe
.
MoeLayerWithBackbone
(
moe_layer
,
config
.
backbone_d_ff
)
inputs
=
make_input_ones
()
outputs
=
layer
(
inputs
)
self
.
assertAllEqual
(
tf
.
shape
(
inputs
),
tf
.
shape
(
outputs
))
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment