ModelZoo / ResNet50_tensorflow / Commits / 4bd15fa6

Commit 4bd15fa6, authored Sep 05, 2019 by A. Unique TensorFlower
Internal change

PiperOrigin-RevId: 267435985

Parent: a009f4fb
Showing 4 changed files with 86 additions and 97 deletions (+86 -97):
  official/bert/modeling.py                            +33  -9
  official/transformer/v2/attention_layer.py           +42  -80
  official/transformer/v2/transformer.py                +6  -4
  official/transformer/v2/transformer_layers_test.py    +5  -4
official/bert/modeling.py

@@ -495,7 +495,23 @@ class Attention(tf.keras.layers.Layer):
 class Dense3D(tf.keras.layers.Layer):
-  """A Dense Layer using 3D kernel with tf.einsum implementation."""
+  """A Dense Layer using 3D kernel with tf.einsum implementation.
+
+  Attributes:
+    num_attention_heads: An integer, number of attention heads for each
+      multihead attention layer.
+    size_per_head: An integer, hidden size per attention head.
+    hidden_size: An integer, dimension of the hidden layer.
+    kernel_initializer: An initializer for the kernel weight.
+    bias_initializer: An initializer for the bias.
+    activation: An activation function to use. If nothing is specified, no
+      activation is applied.
+    use_bias: A bool, whether the layer uses a bias.
+    output_projection: A bool, whether the Dense3D layer is used for output
+      linear projection.
+    backward_compatible: A bool, whether the variables shape are compatible
+      with checkpoints converted from TF 1.x.
+  """

   def __init__(self,
                num_attention_heads=12,

@@ -503,9 +519,11 @@ class Dense3D(tf.keras.layers.Layer):
                kernel_initializer=None,
                bias_initializer="zeros",
                activation=None,
+               use_bias=True,
                output_projection=False,
                backward_compatible=False,
                **kwargs):
+    """Inits Dense3D."""
     super(Dense3D, self).__init__(**kwargs)
     self.num_attention_heads = num_attention_heads
     self.size_per_head = size_per_head

@@ -513,6 +531,7 @@ class Dense3D(tf.keras.layers.Layer):
     self.kernel_initializer = kernel_initializer
     self.bias_initializer = bias_initializer
     self.activation = activation
+    self.use_bias = use_bias
     self.output_projection = output_projection
     self.backward_compatible = backward_compatible

@@ -565,12 +584,15 @@ class Dense3D(tf.keras.layers.Layer):
         initializer=self.kernel_initializer,
         dtype=self.dtype,
         trainable=True)
-    self.bias = self.add_weight(
-        "bias",
-        shape=bias_shape,
-        initializer=self.bias_initializer,
-        dtype=self.dtype,
-        trainable=True)
+    if self.use_bias:
+      self.bias = self.add_weight(
+          "bias",
+          shape=bias_shape,
+          initializer=self.bias_initializer,
+          dtype=self.dtype,
+          trainable=True)
+    else:
+      self.bias = None
     super(Dense3D, self).build(input_shape)

   def call(self, inputs):

@@ -588,7 +610,8 @@ class Dense3D(tf.keras.layers.Layer):
     """
     if self.backward_compatible:
       kernel = tf.keras.backend.reshape(self.kernel, self.kernel_shape)
-      bias = tf.keras.backend.reshape(self.bias, self.bias_shape)
+      bias = (tf.keras.backend.reshape(self.bias, self.bias_shape)
+              if self.use_bias else None)
     else:
       kernel = self.kernel
       bias = self.bias

@@ -597,7 +620,8 @@ class Dense3D(tf.keras.layers.Layer):
       ret = tf.einsum("abcd,cde->abe", inputs, kernel)
     else:
       ret = tf.einsum("abc,cde->abde", inputs, kernel)
-    ret += bias
+    if self.use_bias:
+      ret += bias
    if self.activation is not None:
       return self.activation(ret)
     return ret
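As context for the use_bias change, here is a minimal sketch of the two einsum contractions Dense3D relies on. The shapes and values below are illustrative assumptions, not part of the commit:

import tensorflow as tf

# Assumed toy shapes: batch=2, length=5, hidden=16 split into 4 heads of size 4.
batch, length, hidden, heads, dim_per_head = 2, 5, 16, 4, 4

inputs_3d = tf.random.normal([batch, length, hidden])
input_kernel = tf.random.normal([hidden, heads, dim_per_head])    # regular projection
output_kernel = tf.random.normal([heads, dim_per_head, hidden])   # output_projection=True

# "abc,cde->abde": project and split heads in a single contraction.
per_head = tf.einsum("abc,cde->abde", inputs_3d, input_kernel)
print(per_head.shape)    # (2, 5, 4, 4) == [batch, length, heads, dim_per_head]

# "abcd,cde->abe": recombine heads while projecting back to the hidden size.
recombined = tf.einsum("abcd,cde->abe", per_head, output_kernel)
print(recombined.shape)  # (2, 5, 16) == [batch, length, hidden]

With use_bias=False, which is how the transformer attention layer below constructs these projections, the bias branch added in this file is skipped entirely.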
official/transformer/v2/attention_layer.py

@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function

 import tensorflow as tf
+from official.bert import modeling as common_layer


 class Attention(tf.keras.layers.Layer):

@@ -45,14 +46,19 @@ class Attention(tf.keras.layers.Layer):
   def build(self, input_shape):
     """Builds the layer."""
     # Layers for linearly projecting the queries, keys, and values.
-    self.q_dense_layer = tf.keras.layers.Dense(
-        self.hidden_size, use_bias=False, name="q")
-    self.k_dense_layer = tf.keras.layers.Dense(
-        self.hidden_size, use_bias=False, name="k")
-    self.v_dense_layer = tf.keras.layers.Dense(
-        self.hidden_size, use_bias=False, name="v")
-    self.output_dense_layer = tf.keras.layers.Dense(
-        self.hidden_size, use_bias=False, name="output_transform")
+    size_per_head = self.hidden_size // self.num_heads
+    self.query_dense_layer = common_layer.Dense3D(
+        self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
+        use_bias=False, name="query")
+    self.key_dense_layer = common_layer.Dense3D(
+        self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
+        use_bias=False, name="key")
+    self.value_dense_layer = common_layer.Dense3D(
+        self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
+        use_bias=False, name="value")
+    self.output_dense_layer = common_layer.Dense3D(
+        self.num_heads, size_per_head, kernel_initializer="glorot_uniform",
+        use_bias=False, output_projection=True, name="output_transform")
     super(Attention, self).build(input_shape)

   def get_config(self):

@@ -62,73 +68,35 @@ class Attention(tf.keras.layers.Layer):
         "attention_dropout": self.attention_dropout,
     }

-  def split_heads(self, x):
-    """Split x into different heads, and transpose the resulting value.
-
-    The tensor is transposed to insure the inner dimensions hold the correct
-    values during the matrix multiplication.
-
-    Args:
-      x: A tensor with shape [batch_size, length, hidden_size]
-
-    Returns:
-      A tensor with shape [batch_size, num_heads, length, hidden_size/num_heads]
-    """
-    with tf.name_scope("split_heads"):
-      batch_size = tf.shape(x)[0]
-      length = tf.shape(x)[1]
-
-      # Calculate depth of last dimension after it has been split.
-      depth = (self.hidden_size // self.num_heads)
-
-      # Split the last dimension
-      x = tf.reshape(x, [batch_size, length, self.num_heads, depth])
-
-      # Transpose the result
-      return tf.transpose(x, [0, 2, 1, 3])
-
-  def combine_heads(self, x):
-    """Combine tensor that has been split.
-
-    Args:
-      x: A tensor [batch_size, num_heads, length, hidden_size/num_heads]
-
-    Returns:
-      A tensor with shape [batch_size, length, hidden_size]
-    """
-    with tf.name_scope("combine_heads"):
-      batch_size = tf.shape(x)[0]
-      length = tf.shape(x)[2]
-      x = tf.transpose(x, [0, 2, 1, 3])  # --> [batch, length, num_heads, depth]
-      return tf.reshape(x, [batch_size, length, self.hidden_size])
-
-  def call(self, x, y, bias, training, cache=None, decode_loop_step=None):
-    """Apply attention mechanism to x and y.
+  def call(self, query_input, source_input, bias, training, cache=None,
+           decode_loop_step=None):
+    """Apply attention mechanism to query_input and source_input.

     Args:
-      x: A tensor with shape [batch_size, length_x, hidden_size].
-      y: A tensor with shape [batch_size, length_y, hidden_size].
-      bias: A bool, the attention bias that will be added to the result of the
-        dot product.
+      query_input: A tensor with shape [batch_size, length_query, hidden_size].
+      source_input: A tensor with shape [batch_size, length_source,
+        hidden_size].
+      bias: A tensor with shape [batch_size, 1, length_query, length_source],
+        the attention bias that will be added to the result of the dot product.
       training: A bool, whether in training mode or not.
       cache: (Used during prediction) A dictionary with tensors containing
         results of previous attentions. The dictionary must have the items:
-            {"k": tensor with shape [batch_size, i, key_channels],
-             "v": tensor with shape [batch_size, i, value_channels]}
-        where i is the current decoded length.
+            {"k": tensor with shape [batch_size, i, heads, dim_per_head],
+             "v": tensor with shape [batch_size, i, heads, dim_per_head]}
+        where i is the current decoded length for non-padded decode, or max
+        sequence length for padded decode.
       decode_loop_step: An integer, step number of the decoding loop. Used only
         for autoregressive inference on TPU.

     Returns:
-      Attention layer output with shape [batch_size, length_x, hidden_size]
+      Attention layer output with shape [batch_size, length_query, hidden_size]
     """
     # Linearly project the query, key and value using different learned
-    # projections. This is in preparation of splitting them into multiple
-    # heads. Multi-head attention uses multiple queries, keys, and values
-    # rather than regular attention (which uses a single query, key, value).
-    query = self.q_dense_layer(x)
-    key = self.k_dense_layer(y)
-    value = self.v_dense_layer(y)
+    # projections. Splitting heads is automatically done during the linear
+    # projections --> [batch_size, length, num_heads, dim_per_head].
+    query = self.query_dense_layer(query_input)
+    key = self.key_dense_layer(source_input)
+    value = self.value_dense_layer(source_input)

     if cache is not None:
       # Combine cached keys and values with new keys and values.

@@ -136,12 +104,12 @@ class Attention(tf.keras.layers.Layer):
         cache_k_shape = cache["k"].shape.as_list()
         indices = tf.reshape(
             tf.one_hot(decode_loop_step, cache_k_shape[1], dtype=key.dtype),
-            [1, cache_k_shape[1], 1])
+            [1, cache_k_shape[1], 1, 1])
         key = cache["k"] + key * indices
         cache_v_shape = cache["v"].shape.as_list()
         indices = tf.reshape(
             tf.one_hot(decode_loop_step, cache_v_shape[1], dtype=value.dtype),
-            [1, cache_v_shape[1], 1])
+            [1, cache_v_shape[1], 1, 1])
         value = cache["v"] + value * indices
       else:
         key = tf.concat([tf.cast(cache["k"], key.dtype), key], axis=1)

@@ -151,18 +119,13 @@ class Attention(tf.keras.layers.Layer):
       cache["k"] = key
       cache["v"] = value

-    # Split query, key, value into heads.
-    query = self.split_heads(query)
-    key = self.split_heads(key)
-    value = self.split_heads(value)
-
     # Scale query to prevent the dot product between query and key from growing
     # too large.
     depth = (self.hidden_size // self.num_heads)
     query *= depth ** -0.5

     # Calculate dot product attention
-    logits = tf.matmul(query, key, transpose_b=True)
+    logits = tf.einsum("BTNH,BFNH->BNFT", key, query)
     logits += bias
     # Note that softmax internally performs math operations using float32
     # for numeric stability. When training with float16, we keep the input

@@ -170,12 +133,10 @@ class Attention(tf.keras.layers.Layer):
     weights = tf.nn.softmax(logits, name="attention_weights")
     if training:
       weights = tf.nn.dropout(weights, rate=self.attention_dropout)
-    attention_output = tf.matmul(weights, value)
-
-    # Recombine heads --> [batch_size, length, hidden_size]
-    attention_output = self.combine_heads(attention_output)
+    attention_output = tf.einsum("BNFT,BTNH->BFNH", weights, value)

-    # Run the combined outputs through another linear projection layer.
+    # Run the outputs through another linear projection layer. Recombining heads
+    # is automatically done --> [batch_size, length, hidden_size]
     attention_output = self.output_dense_layer(attention_output)
     return attention_output

@@ -183,6 +144,7 @@ class Attention(tf.keras.layers.Layer):
 class SelfAttention(Attention):
   """Multiheaded self-attention layer."""

-  def call(self, x, bias, training, cache=None, decode_loop_step=None):
-    return super(SelfAttention, self).call(x, x, bias, training, cache,
-                                           decode_loop_step)
+  def call(self, query_input, bias, training, cache=None,
+           decode_loop_step=None):
+    return super(SelfAttention, self).call(
+        query_input, query_input, bias, training, cache, decode_loop_step)
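The two einsum strings replace the explicit split_heads/matmul/combine_heads sequence. A hedged equivalence check, starting from tensors that are already head-split; the shapes and values below are arbitrary assumptions for illustration:

import tensorflow as tf

# Assumed shapes: B=batch, F=T=length, N=heads, H=dim per head.
B, L, N, H = 2, 5, 4, 8
q = tf.random.normal([B, L, N, H])   # [B, F, N, H], as Dense3D would produce
k = tf.random.normal([B, L, N, H])   # [B, T, N, H]
v = tf.random.normal([B, L, N, H])   # [B, T, N, H]

# New path: contract over the head dimension directly.
logits_new = tf.einsum("BTNH,BFNH->BNFT", k, q)
out_new = tf.einsum("BNFT,BTNH->BFNH", tf.nn.softmax(logits_new), v)

# Old path: transpose to [B, N, length, H], matmul, transpose back.
qt, kt, vt = (tf.transpose(t, [0, 2, 1, 3]) for t in (q, k, v))
logits_old = tf.matmul(qt, kt, transpose_b=True)
out_old = tf.transpose(tf.matmul(tf.nn.softmax(logits_old), vt), [0, 2, 1, 3])

# The two computations agree up to float error.
print(bool(tf.reduce_all(tf.abs(out_new - out_old) < 1e-4)))  # True

Keeping the head axis explicit end to end is what lets the decode cache hold [batch_size, length, num_heads, dim_per_head] entries, which the remaining two files switch over to.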
official/transformer/v2/transformer.py

@@ -200,7 +200,7 @@ class Transformer(tf.keras.Model):
       # Prepare inputs to decoder layers by shifting targets, adding positional
       # encoding and applying dropout.
       decoder_inputs = self.embedding_softmax_layer(targets)
-      decoder_inputs = tf.cast(decoder_inputs, self.params['dtype'])
+      decoder_inputs = tf.cast(decoder_inputs, self.params["dtype"])
       attention_bias = tf.cast(attention_bias, self.params["dtype"])
       with tf.name_scope("shift_targets"):
         # Shift targets to the right, and remove the last element

@@ -218,7 +218,7 @@ class Transformer(tf.keras.Model):
     # Run values
     decoder_self_attention_bias = model_utils.get_decoder_self_attention_bias(
-        length, dtype=self.params['dtype'])
+        length, dtype=self.params["dtype"])
     outputs = self.decoder_stack(
         decoder_inputs,
         encoder_outputs,

@@ -310,16 +310,18 @@ class Transformer(tf.keras.Model):
     # pylint: disable=g-complex-comprehension
     init_decode_length = (
         max_decode_length if self.params["padded_decode"] else 0)
+    num_heads = self.params["num_heads"]
+    dim_per_head = self.params["hidden_size"] // num_heads
     cache = {
         "layer_%d" % layer: {
             "k":
                 tf.zeros([
-                    batch_size, init_decode_length, self.params["hidden_size"]
+                    batch_size, init_decode_length, num_heads, dim_per_head
                 ],
                          dtype=self.params["dtype"]),
             "v":
                 tf.zeros([
-                    batch_size, init_decode_length, self.params["hidden_size"]
+                    batch_size, init_decode_length, num_heads, dim_per_head
                 ],
                          dtype=self.params["dtype"])
         } for layer in range(self.params["num_hidden_layers"])
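A compact sketch of what the reshaped cache initialization produces; the params values below are assumptions chosen only to make the shapes concrete:

import tensorflow as tf

# Assumed (illustrative) decoding config.
params = {"num_heads": 4, "hidden_size": 32, "num_hidden_layers": 2,
          "dtype": tf.float32, "padded_decode": True}
batch_size, max_decode_length = 2, 6

init_decode_length = max_decode_length if params["padded_decode"] else 0
num_heads = params["num_heads"]
dim_per_head = params["hidden_size"] // num_heads

# One pre-split k/v buffer per decoder layer, matching what the attention
# layer now reads and writes.
cache = {
    "layer_%d" % layer: {
        "k": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
                      dtype=params["dtype"]),
        "v": tf.zeros([batch_size, init_decode_length, num_heads, dim_per_head],
                      dtype=params["dtype"]),
    }
    for layer in range(params["num_hidden_layers"])
}

print(cache["layer_0"]["k"].shape)  # (2, 6, 4, 8) rather than the old (2, 6, 32)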
official/transformer/v2/transformer_layers_test.py

@@ -32,6 +32,7 @@ class TransformerLayersTest(tf.test.TestCase):
     hidden_size = 64
     num_heads = 4
     dropout = 0.5
+    dim_per_head = hidden_size // num_heads
     layer = attention_layer.SelfAttention(hidden_size, num_heads, dropout)
     self.assertDictEqual(layer.get_config(), {
         "hidden_size": hidden_size,

@@ -42,13 +43,13 @@ class TransformerLayersTest(tf.test.TestCase):
     x = tf.ones([1, length, hidden_size])
     bias = tf.ones([1])
     cache = {
-        "k": tf.zeros([1, 0, hidden_size]),
-        "v": tf.zeros([1, 0, hidden_size]),
+        "k": tf.zeros([1, 0, num_heads, dim_per_head]),
+        "v": tf.zeros([1, 0, num_heads, dim_per_head]),
     }
     y = layer(x, bias, training=True, cache=cache)
     self.assertEqual(y.shape, (1, length, 64,))
-    self.assertEqual(cache["k"].shape, (1, length, 64,))
-    self.assertEqual(cache["v"].shape, (1, length, 64,))
+    self.assertEqual(cache["k"].shape, (1, length, num_heads, dim_per_head,))
+    self.assertEqual(cache["v"].shape, (1, length, num_heads, dim_per_head,))

   def test_embedding_shared_weights(self):
     vocab_size = 50
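The cache shapes the test now asserts are the same layout the padded-decode path writes into. A rough sketch of that update, mirroring the [1, cache_k_shape[1], 1, 1] reshape in attention_layer.py; the shapes below are assumed for illustration:

import tensorflow as tf

# Assumed shapes for a padded decode loop.
batch, max_len, num_heads, dim_per_head = 1, 8, 4, 16
decode_loop_step = 3  # position currently being decoded

cache_k = tf.zeros([batch, max_len, num_heads, dim_per_head])    # preallocated buffer
step_k = tf.random.normal([batch, 1, num_heads, dim_per_head])   # key for this step only

# One-hot mask over the length axis; the two trailing 1s broadcast over the
# new num_heads and dim_per_head axes of the cache.
indices = tf.reshape(
    tf.one_hot(decode_loop_step, max_len, dtype=step_k.dtype),
    [1, max_len, 1, 1])

# Add-with-mask writes the step's key into row `decode_loop_step` without a
# dynamic scatter, which keeps the op TPU-friendly.
cache_k = cache_k + step_k * indices
print(cache_k.shape)  # (1, 8, 4, 16), the same layout the test asserts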