ModelZoo / ResNet50_tensorflow · Commits · Commit 1e9cbdce

Authored Mar 20, 2021 by Jeremiah Liu; committed by A. Unique TensorFlower, Mar 20, 2021.

Adds spectral normalization and Gaussian process layers.

PiperOrigin-RevId: 364109398
Parent: 1a8f129b

Showing 4 changed files with 1068 additions and 0 deletions (+1068, -0):
official/nlp/modeling/layers/gaussian_process.py (+460, -0)
official/nlp/modeling/layers/gaussian_process_test.py (+227, -0)
official/nlp/modeling/layers/spectral_normalization.py (+295, -0)
official/nlp/modeling/layers/spectral_normalization_test.py (+86, -0)
official/nlp/modeling/layers/gaussian_process.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Definitions for random feature Gaussian process layer.
## References:
[1]: Ali Rahimi and Benjamin Recht. Random Features for Large-Scale Kernel
Machines. In _Neural Information Processing Systems_, 2007.
https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf
"""
import math

import tensorflow as tf

_SUPPORTED_LIKELIHOOD = ('binary_logistic', 'poisson', 'gaussian')


class RandomFeatureGaussianProcess(tf.keras.layers.Layer):
"""Gaussian process layer with random feature approximation.
During training, the model updates the maximum a posteriori (MAP) logits
estimates and posterior precision matrix using minibatch statistics. During
inference, the model divides the MAP logit estimates by the predictive
standard deviation, which is equivalent to approximating the posterior mean
of the predictive probability via the mean-field approximation.
User can specify different types of random features by setting
`use_custom_random_features=True`, and change the initializer and activations
of the custom random features. For example:
MLP Kernel: initializer='random_normal', activation=tf.nn.relu
RBF Kernel: initializer='random_normal', activation=tf.math.cos
A linear kernel can also be specified by setting gp_kernel_type='linear' and
`use_custom_random_features=True`.
Attributes:
units: (int) The dimensionality of layer.
num_inducing: (int) The number of random features for the approximation.
is_training: (tf.bool) Whether the layer is set in training mode. If so the
layer updates the Gaussian process' variance estimate using statistics
computed from the incoming minibatches.
"""
  def __init__(self,
               units,
               num_inducing=1024,
               gp_kernel_type='gaussian',
               gp_kernel_scale=1.,
               gp_output_bias=0.,
               normalize_input=True,
               gp_kernel_scale_trainable=False,
               gp_output_bias_trainable=False,
               gp_cov_momentum=0.999,
               gp_cov_ridge_penalty=1e-6,
               scale_random_features=True,
               use_custom_random_features=True,
               custom_random_features_initializer=None,
               custom_random_features_activation=None,
               l2_regularization=0.,
               gp_cov_likelihood='gaussian',
               return_gp_cov=True,
               return_random_features=False,
               dtype=None,
               name='random_feature_gaussian_process',
               **gp_output_kwargs):
"""Initializes a random-feature Gaussian process layer instance.
Args:
units: (int) Number of output units.
num_inducing: (int) Number of random Fourier features used for
approximating the Gaussian process.
gp_kernel_type: (string) The type of kernel function to use for Gaussian
process. Currently default to 'gaussian' which is the Gaussian RBF
kernel.
gp_kernel_scale: (float) The length-scale parameter of the a
shift-invariant kernel function, i.e., for RBF kernel:
exp(-|x1 - x2|**2 / gp_kernel_scale).
gp_output_bias: (float) Scalar initial value for the bias vector.
normalize_input: (bool) Whether to normalize the input to Gaussian
process.
gp_kernel_scale_trainable: (bool) Whether the length scale variable is
trainable.
gp_output_bias_trainable: (bool) Whether the bias is trainable.
gp_cov_momentum: (float) A discount factor used to compute the moving
average for posterior covariance matrix.
gp_cov_ridge_penalty: (float) Initial Ridge penalty to posterior
covariance matrix.
scale_random_features: (bool) Whether to scale the random feature
by sqrt(2. / num_inducing).
use_custom_random_features: (bool) Whether to use custom random
features implemented using tf.keras.layers.Dense.
custom_random_features_initializer: (str or callable) Initializer for
the random features. Default to random normal which approximates a RBF
kernel function if activation function is cos.
custom_random_features_activation: (callable) Activation function for the
random feature layer. Default to cosine which approximates a RBF
kernel function.
l2_regularization: (float) The strength of l2 regularization on the output
weights.
gp_cov_likelihood: (string) Likelihood to use for computing Laplace
approximation for covariance matrix. Default to `gaussian`.
return_gp_cov: (bool) Whether to also return GP covariance matrix.
If False then no covariance learning is performed.
return_random_features: (bool) Whether to also return random features.
dtype: (tf.DType) Input data type.
name: (string) Layer name.
**gp_output_kwargs: Additional keyword arguments to dense output layer.
"""
    super(RandomFeatureGaussianProcess, self).__init__(name=name, dtype=dtype)
    self.units = units
    self.num_inducing = num_inducing

    self.normalize_input = normalize_input
    self.gp_input_scale = 1. / tf.sqrt(gp_kernel_scale)
    self.gp_feature_scale = tf.sqrt(2. / float(num_inducing))
    self.scale_random_features = scale_random_features

    self.return_random_features = return_random_features
    self.return_gp_cov = return_gp_cov

    self.gp_kernel_type = gp_kernel_type
    self.gp_kernel_scale = gp_kernel_scale
    self.gp_output_bias = gp_output_bias
    self.gp_kernel_scale_trainable = gp_kernel_scale_trainable
    self.gp_output_bias_trainable = gp_output_bias_trainable

    self.use_custom_random_features = use_custom_random_features
    self.custom_random_features_initializer = custom_random_features_initializer
    self.custom_random_features_activation = custom_random_features_activation

    self.l2_regularization = l2_regularization
    self.gp_output_kwargs = gp_output_kwargs

    self.gp_cov_momentum = gp_cov_momentum
    self.gp_cov_ridge_penalty = gp_cov_ridge_penalty
    self.gp_cov_likelihood = gp_cov_likelihood

    if self.use_custom_random_features:
      # Default to a Gaussian RBF kernel.
      self.random_features_bias_initializer = tf.random_uniform_initializer(
          minval=0., maxval=2. * math.pi)
      if self.custom_random_features_initializer is None:
        self.custom_random_features_initializer = (
            tf.keras.initializers.RandomNormal(stddev=1.))
      if self.custom_random_features_activation is None:
        self.custom_random_features_activation = tf.math.cos

  def build(self, input_shape):
    # Defines model layers.
    if self.normalize_input:
      self._input_norm_layer = tf.keras.layers.LayerNormalization(
          name='gp_input_normalization')
      self._input_norm_layer.build(input_shape)
      input_shape = self._input_norm_layer.compute_output_shape(input_shape)

    self._random_feature = self._make_random_feature_layer(
        name='gp_random_feature')
    self._random_feature.build(input_shape)
    input_shape = self._random_feature.compute_output_shape(input_shape)

    if self.return_gp_cov:
      self._gp_cov_layer = LaplaceRandomFeatureCovariance(
          momentum=self.gp_cov_momentum,
          ridge_penalty=self.gp_cov_ridge_penalty,
          likelihood=self.gp_cov_likelihood,
          dtype=self.dtype,
          name='gp_covariance')
      self._gp_cov_layer.build(input_shape)

    self._gp_output_layer = tf.keras.layers.Dense(
        units=self.units,
        use_bias=False,
        kernel_regularizer=tf.keras.regularizers.l2(self.l2_regularization),
        dtype=self.dtype,
        name='gp_output_weights',
        **self.gp_output_kwargs)
    self._gp_output_layer.build(input_shape)

    self._gp_output_bias = tf.Variable(
        initial_value=[self.gp_output_bias] * self.units,
        dtype=self.dtype,
        trainable=self.gp_output_bias_trainable,
        name='gp_output_bias')

    self.built = True

  def _make_random_feature_layer(self, name):
    """Defines the random feature layer depending on the kernel type."""
    if not self.use_custom_random_features:
      # Use the default RandomFourierFeatures layer from tf.keras.
      return tf.keras.layers.experimental.RandomFourierFeatures(
          output_dim=self.num_inducing,
          kernel_initializer=self.gp_kernel_type,
          scale=self.gp_kernel_scale,
          trainable=self.gp_kernel_scale_trainable,
          dtype=self.dtype,
          name=name)

    if self.gp_kernel_type.lower() == 'linear':
      custom_random_feature_layer = tf.keras.layers.Lambda(
          lambda x: x, name=name)
    else:
      # Use user-supplied configurations.
      custom_random_feature_layer = tf.keras.layers.Dense(
          units=self.num_inducing,
          use_bias=True,
          activation=self.custom_random_features_activation,
          kernel_initializer=self.custom_random_features_initializer,
          bias_initializer=self.random_features_bias_initializer,
          trainable=False,
          name=name)

    return custom_random_feature_layer
  def reset_covariance_matrix(self):
    """Resets the covariance matrix of the GP layer.

    This function is useful for resetting the model's covariance matrix at
    the beginning of a new epoch.
    """
    self._gp_cov_layer.reset_precision_matrix()
  def call(self, inputs, global_step=None, training=None):
    # Computes random features.
    gp_inputs = inputs
    if self.normalize_input:
      gp_inputs = self._input_norm_layer(gp_inputs)
    elif self.use_custom_random_features:
      # Supports a length scale for the custom random feature layer by
      # directly rescaling the input.
      gp_input_scale = tf.cast(self.gp_input_scale, inputs.dtype)
      gp_inputs = gp_inputs * gp_input_scale

    gp_feature = self._random_feature(gp_inputs)

    if self.scale_random_features:
      # Scale random features by sqrt(2. / num_inducing) following [1].
      # When using the GP layer as the output layer of a neural network, it
      # is recommended to turn this scaling off to prevent it from changing
      # the learning rate of the hidden layers.
      gp_feature_scale = tf.cast(self.gp_feature_scale, inputs.dtype)
      gp_feature = gp_feature * gp_feature_scale

    # Computes the posterior center (i.e., MAP estimate) and variance.
    gp_output = self._gp_output_layer(gp_feature) + self._gp_output_bias

    if self.return_gp_cov:
      gp_covmat = self._gp_cov_layer(gp_feature, gp_output, training)

    # Assembles model output.
    model_output = [gp_output,]
    if self.return_gp_cov:
      model_output.append(gp_covmat)
    if self.return_random_features:
      model_output.append(gp_feature)

    return model_output
class LaplaceRandomFeatureCovariance(tf.keras.layers.Layer):
"""Computes the Gaussian Process covariance using Laplace method.
At training time, this layer updates the Gaussian process posterior using
model features in minibatches.
Attributes:
momentum: (float) A discount factor used to compute the moving average for
posterior precision matrix. Analogous to the momentum factor in batch
normalization. If -1 then update covariance matrix using a naive sum
without momentum, which is desirable if the goal is to compute the exact
covariance matrix by passing through data once (say in the final epoch).
ridge_penalty: (float) Initial Ridge penalty to weight covariance matrix.
This value is used to stablize the eigenvalues of weight covariance
estimate so that the matrix inverse can be computed for Cov = inv(t(X) * X
+ s * I). The ridge factor s cannot be too large since otherwise it will
dominate the t(X) * X term and make covariance estimate not meaningful.
likelihood: (str) The likelihood to use for computing Laplace approximation
for the covariance matrix. Can be one of ('binary_logistic', 'poisson',
'gaussian').
"""
  def __init__(self,
               momentum=0.999,
               ridge_penalty=1e-6,
               likelihood='gaussian',
               dtype=None,
               name='laplace_covariance'):
    if likelihood not in _SUPPORTED_LIKELIHOOD:
      raise ValueError(
          f'"likelihood" must be one of {_SUPPORTED_LIKELIHOOD}, '
          f'got {likelihood}.')
    self.ridge_penalty = ridge_penalty
    self.momentum = momentum
    self.likelihood = likelihood
    super(LaplaceRandomFeatureCovariance, self).__init__(
        dtype=dtype, name=name)

  def compute_output_shape(self, input_shape):
    gp_feature_dim = input_shape[-1]
    return tf.TensorShape([gp_feature_dim, gp_feature_dim])

  def build(self, input_shape):
    gp_feature_dim = input_shape[-1]

    # Convert gp_feature_dim to an int value for TF1 compatibility.
    if isinstance(gp_feature_dim, tf.compat.v1.Dimension):
      gp_feature_dim = gp_feature_dim.value

    # Posterior precision matrix for the GP's random feature coefficients.
    self.initial_precision_matrix = (
        self.ridge_penalty * tf.eye(gp_feature_dim, dtype=self.dtype))

    self.precision_matrix = (
        self.add_weight(
            name='gp_precision_matrix',
            shape=(gp_feature_dim, gp_feature_dim),
            dtype=self.dtype,
            initializer=tf.keras.initializers.Identity(self.ridge_penalty),
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA))

    self.built = True
  def make_precision_matrix_update_op(self,
                                      gp_feature,
                                      logits,
                                      precision_matrix):
    """Defines the update op for the precision matrix of feature weights."""
    if self.likelihood != 'gaussian':
      if logits is None:
        raise ValueError(
            f'"logits" cannot be None when likelihood={self.likelihood}.')

      if logits.shape[-1] != 1:
        raise ValueError(
            f'likelihood={self.likelihood} only supports univariate logits. '
            f'Got logits dimension: {logits.shape[-1]}.')

    batch_size = tf.shape(gp_feature)[0]
    batch_size = tf.cast(batch_size, dtype=gp_feature.dtype)

    # Computes the batch-specific normalized precision matrix.
    if self.likelihood == 'binary_logistic':
      prob = tf.sigmoid(logits)
      prob_multiplier = prob * (1. - prob)
    elif self.likelihood == 'poisson':
      prob_multiplier = tf.exp(logits)
    else:
      prob_multiplier = 1.

    gp_feature_adjusted = tf.sqrt(prob_multiplier) * gp_feature
    precision_matrix_minibatch = tf.matmul(
        gp_feature_adjusted, gp_feature_adjusted, transpose_a=True)

    # Updates the population-wise precision matrix.
    if self.momentum > 0:
      # Use moving-average updates to accumulate batch-specific precision
      # matrices.
      precision_matrix_minibatch = precision_matrix_minibatch / batch_size
      precision_matrix_new = (
          self.momentum * precision_matrix +
          (1. - self.momentum) * precision_matrix_minibatch)
    else:
      # Compute the exact population-wise covariance without momentum.
      # If using this option, make sure to pass through the data only once.
      precision_matrix_new = precision_matrix + precision_matrix_minibatch

    # Returns the update op.
    return precision_matrix.assign(precision_matrix_new)
  def reset_precision_matrix(self):
    """Resets the precision matrix to its initial value.

    This function is useful for resetting the model's covariance matrix at
    the beginning of a new epoch.
    """
    precision_matrix_reset_op = self.precision_matrix.assign(
        self.initial_precision_matrix)
    self.add_update(precision_matrix_reset_op)
  def compute_predictive_covariance(self, gp_feature):
    """Computes the posterior predictive variance.

    Approximates the Gaussian process posterior using random features.
    Given training random features Phi_tr (num_train, num_hidden) and testing
    random features Phi_ts (batch_size, num_hidden), the predictive
    covariance matrix is computed as (assuming a Gaussian likelihood):

      s * Phi_ts @ inv(t(Phi_tr) * Phi_tr + s * I) @ t(Phi_ts),

    where s is the ridge factor used to stabilize the inverse, and I is the
    identity matrix with shape (num_hidden, num_hidden).

    Args:
      gp_feature: (tf.Tensor) The random features of the test data to be used
        for computing the covariance matrix. Shape (batch_size,
        gp_hidden_size).

    Returns:
      (tf.Tensor) Predictive covariance matrix, shape (batch_size,
      batch_size).
    """
    # Computes the covariance matrix of the feature coefficients.
    feature_cov_matrix = tf.linalg.inv(self.precision_matrix)

    # Computes the covariance matrix of the GP prediction.
    cov_feature_product = tf.matmul(
        feature_cov_matrix, gp_feature, transpose_b=True) * self.ridge_penalty
    gp_cov_matrix = tf.matmul(gp_feature, cov_feature_product)

    return gp_cov_matrix
  def _get_training_value(self, training=None):
    if training is None:
      training = tf.keras.backend.learning_phase()

    if isinstance(training, int):
      training = bool(training)

    return training

  def call(self, inputs, logits=None, training=None):
    """Minibatch updates the GP's posterior precision matrix estimate.

    Args:
      inputs: (tf.Tensor) GP random features, shape (batch_size,
        gp_hidden_size).
      logits: (tf.Tensor) Pre-activation output from the model. Needed for
        the Laplace approximation under a non-Gaussian likelihood.
      training: (tf.bool) Whether the layer is in training mode. If in
        training mode, the gp_weight covariance is updated using gp_feature.

    Returns:
      gp_stddev (tf.Tensor): GP posterior predictive variance, shape
        (batch_size, batch_size).
    """
    batch_size = tf.shape(inputs)[0]
    training = self._get_training_value(training)

    if training:
      # Define and register the update op for the feature precision matrix.
      precision_matrix_update_op = self.make_precision_matrix_update_op(
          gp_feature=inputs,
          logits=logits,
          precision_matrix=self.precision_matrix)
      self.add_update(precision_matrix_update_op)
      # Return a null estimate during training.
      return tf.eye(batch_size, dtype=self.dtype)
    else:
      # Return the covariance estimate during inference.
      return self.compute_predictive_covariance(gp_feature=inputs)
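
Usage note (illustrative, not part of this commit): the class docstring above says that, at inference time, the MAP logits are divided by the predictive standard deviation (the mean-field approximation), but the layer itself only returns the logits and the covariance matrix. A minimal sketch of that adjustment follows; the `hidden` tensor and the lambda = pi/8 logistic mean-field factor are assumptions.

import math

import tensorflow as tf

from official.nlp.modeling.layers import gaussian_process

# Hypothetical encoder output of shape (batch_size, hidden_dim).
hidden = tf.random.normal((8, 128))

gp_head = gaussian_process.RandomFeatureGaussianProcess(units=1)

# Training call: registers the precision-matrix update op and returns an
# identity placeholder in place of the covariance.
logits, _ = gp_head(hidden, training=True)

# Inference call: returns the MAP logits and the (batch, batch) predictive
# covariance matrix.
logits, covmat = gp_head(hidden, training=False)
variance = tf.linalg.diag_part(covmat)[:, None]

# Mean-field adjustment described in the class docstring: scale the MAP
# logits by the predictive standard deviation. The lambda = pi / 8 factor is
# the standard logistic mean-field approximation and is an assumption here,
# not something defined in this commit.
adjusted_logits = logits / tf.sqrt(1. + (math.pi / 8.) * variance)
probs = tf.sigmoid(adjusted_logits)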
official/nlp/modeling/layers/gaussian_process_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Lint as: python3
"""Tests for Gaussian process functions."""
import os
import shutil

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import gaussian_process


def exact_gaussian_kernel(x1, x2):
  """Computes exact Gaussian kernel value(s) for tensors x1 and x2."""
  x1_squared = tf.reduce_sum(tf.square(x1), list(range(1, len(x1.shape))))
  x2_squared = tf.reduce_sum(tf.square(x2), list(range(1, len(x2.shape))))
  square = (x1_squared[:, tf.newaxis] + x2_squared[tf.newaxis, :] -
            2 * tf.matmul(x1, x2, transpose_b=True))
  return tf.math.exp(-square / 2.)


def _generate_normal_data(num_sample, num_dim, loc):
  """Generates random data sampled from an i.i.d. normal distribution."""
  return np.random.normal(
      size=(num_sample, num_dim), loc=loc, scale=1. / np.sqrt(num_dim))


def _generate_rbf_data(x_data, orthogonal=True):
  """Generates high-dim data that is the eigen components of an RBF kernel."""
  k_rbf = exact_gaussian_kernel(x_data, x_data)
  x_orth, x_diag, _ = np.linalg.svd(k_rbf)
  if orthogonal:
    return x_orth
  return np.diag(np.sqrt(x_diag)).dot(x_orth.T)


def _make_minibatch_iterator(data_numpy, batch_size, num_epoch):
  """Makes a tf.data.Dataset for a given batch size and number of epochs."""
  dataset = tf.data.Dataset.from_tensor_slices(data_numpy)
  dataset = dataset.repeat(num_epoch).batch(batch_size)
  return iter(dataset)


def _compute_posterior_kernel(x_tr, x_ts, kernel_func, ridge_penalty):
  """Computes the posterior covariance matrix of a Gaussian process."""
  num_sample = x_tr.shape[0]
  k_tt_inv = tf.linalg.inv(
      kernel_func(x_tr, x_tr) + ridge_penalty * np.eye(num_sample))
  k_ts = kernel_func(x_tr, x_ts)
  k_ss = kernel_func(x_ts, x_ts)

  return k_ss - tf.matmul(k_ts, tf.matmul(k_tt_inv, k_ts), transpose_a=True)
class GaussianProcessTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(GaussianProcessTest, self).setUp()
    self.num_data_dim = 10
    self.num_inducing = 1024
    self.num_train_sample = 1024
    self.num_test_sample = 256
    self.prec_tolerance = {'atol': 1e-3, 'rtol': 5e-2}
    self.cov_tolerance = {'atol': 5e-2, 'rtol': 2.}

    self.rbf_kern_func = exact_gaussian_kernel

    self.x_tr = _generate_normal_data(
        self.num_train_sample, self.num_data_dim, loc=0.)
    self.x_ts = _generate_normal_data(
        self.num_test_sample, self.num_data_dim, loc=1.)
  def test_layer_build(self):
    """Tests that layer.built is True after building."""
    rfgp_model = gaussian_process.RandomFeatureGaussianProcess(units=1)
    rfgp_model.build(input_shape=self.x_tr.shape)
    self.assertTrue(rfgp_model.built)

  @parameterized.named_parameters(('rbf_data', False),
                                  ('orthogonal_data', True))
  def test_laplace_covariance_minibatch(self, generate_orthogonal_data):
    """Tests that the model learns the population-level precision matrix."""
    batch_size = 50
    epochs = 1000
    x_data = _generate_rbf_data(self.x_ts, generate_orthogonal_data)
    data_iterator = _make_minibatch_iterator(x_data, batch_size, epochs)

    # Estimates the precision matrix using minibatches.
    cov_estimator = gaussian_process.LaplaceRandomFeatureCovariance(
        momentum=0.999, ridge_penalty=0)

    for minibatch_data in data_iterator:
      _ = cov_estimator(minibatch_data, training=True)

    # Evaluation.
    prec_mat_expected = x_data.T.dot(x_data)
    prec_mat_computed = (
        cov_estimator.precision_matrix.numpy() * self.num_test_sample)

    np.testing.assert_allclose(
        prec_mat_computed, prec_mat_expected, **self.prec_tolerance)
  def test_random_feature_prior_approximation(self):
    """Tests the random-feature GP's ability to approximate the exact GP prior."""
    num_inducing = 10240
    rfgp_model = gaussian_process.RandomFeatureGaussianProcess(
        units=1,
        num_inducing=num_inducing,
        normalize_input=False,
        gp_kernel_type='gaussian',
        return_random_features=True)

    # Extract random features.
    _, _, gp_feature = rfgp_model(self.x_tr, training=True)
    gp_feature_np = gp_feature.numpy()

    prior_kernel_computed = gp_feature_np.dot(gp_feature_np.T)
    prior_kernel_expected = self.rbf_kern_func(self.x_tr, self.x_tr)

    np.testing.assert_allclose(
        prior_kernel_computed, prior_kernel_expected, **self.cov_tolerance)
  def test_random_feature_posterior_approximation(self):
    """Tests the random-feature GP's ability to approximate the exact GP posterior."""
    # Set momentum = 0.5 so the posterior precision matrix is 0.5 * (I + K).
    gp_cov_momentum = 0.5
    gp_cov_ridge_penalty = 1.
    num_inducing = 1024

    rfgp_model = gaussian_process.RandomFeatureGaussianProcess(
        units=1,
        num_inducing=num_inducing,
        normalize_input=False,
        gp_kernel_type='gaussian',
        gp_cov_momentum=gp_cov_momentum,
        gp_cov_ridge_penalty=gp_cov_ridge_penalty)

    # Computes the posterior covariance on test data.
    _, _ = rfgp_model(self.x_tr, training=True)
    _, gp_cov_ts = rfgp_model(self.x_ts, training=False)

    # Scale up the covariance estimate, since the precision matrix is
    # down-scaled by momentum.
    post_kernel_computed = gp_cov_ts * gp_cov_momentum
    post_kernel_expected = _compute_posterior_kernel(self.x_tr, self.x_ts,
                                                     self.rbf_kern_func,
                                                     gp_cov_ridge_penalty)
    np.testing.assert_allclose(
        post_kernel_computed, post_kernel_expected, **self.cov_tolerance)
  def test_random_feature_linear_kernel(self):
    """Tests if the linear kernel indeed leads to an identity mapping."""
    # Specify the linear kernel.
    gp_kernel_type = 'linear'
    normalize_input = False
    scale_random_features = False
    use_custom_random_features = True

    rfgp_model = gaussian_process.RandomFeatureGaussianProcess(
        units=1,
        normalize_input=normalize_input,
        gp_kernel_type=gp_kernel_type,
        scale_random_features=scale_random_features,
        use_custom_random_features=use_custom_random_features,
        return_random_features=True)

    _, _, gp_feature = rfgp_model(self.x_tr, training=True)

    # Check that the linear kernel leads to an identity mapping.
    np.testing.assert_allclose(gp_feature, self.x_tr, **self.prec_tolerance)
  def test_no_matrix_update_during_test(self):
    """Tests that the precision matrix is not updated during testing."""
    rfgp_model = gaussian_process.RandomFeatureGaussianProcess(units=1)

    # Training.
    _, gp_covmat_null = rfgp_model(self.x_tr, training=True)
    precision_mat_before_test = rfgp_model._gp_cov_layer.precision_matrix

    # Testing.
    _ = rfgp_model(self.x_ts, training=False)
    precision_mat_after_test = rfgp_model._gp_cov_layer.precision_matrix

    self.assertAllClose(
        gp_covmat_null, tf.eye(self.num_train_sample), atol=1e-4)
    self.assertAllClose(
        precision_mat_before_test, precision_mat_after_test, atol=1e-4)
  def test_state_saving_and_loading(self):
    """Tests that the loaded model returns the same results."""
    input_data = np.random.random((1, 2))
    rfgp_model = gaussian_process.RandomFeatureGaussianProcess(units=1)

    inputs = tf.keras.Input((2,), batch_size=1)
    outputs = rfgp_model(inputs)
    model = tf.keras.Model(inputs, outputs)
    gp_output, gp_covmat = model.predict(input_data)

    # Save and then load the model.
    temp_dir = self.get_temp_dir()
    self.addCleanup(shutil.rmtree, temp_dir)
    saved_model_dir = os.path.join(temp_dir, 'rfgp_model')
    model.save(saved_model_dir)
    new_model = tf.keras.models.load_model(saved_model_dir)

    gp_output_new, gp_covmat_new = new_model.predict(input_data)
    self.assertAllClose(gp_output, gp_output_new, atol=1e-4)
    self.assertAllClose(gp_covmat, gp_covmat_new, atol=1e-4)


if __name__ == '__main__':
  tf.test.main()
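
Usage note (illustrative, not part of this commit): `reset_covariance_matrix` is documented above as useful at the beginning of a new epoch, and with `gp_cov_momentum=-1` the precision matrix accumulates an exact, momentum-free sum over minibatches, so it should reflect exactly one pass over the data. A minimal sketch of that protocol, where the dataset and loop structure are assumptions:

import tensorflow as tf

from official.nlp.modeling.layers import gaussian_process

# gp_cov_momentum=-1 accumulates an exact sum over minibatches instead of a
# moving average.
gp_layer = gaussian_process.RandomFeatureGaussianProcess(
    units=1, gp_cov_momentum=-1)

# Hypothetical stand-in data; any (num_samples, feature_dim) array works.
dataset = tf.data.Dataset.from_tensor_slices(
    tf.random.normal((256, 16))).batch(32)

for epoch in range(3):
  if epoch > 0:
    # Discard the precision matrix accumulated during the previous epoch.
    gp_layer.reset_covariance_matrix()
  for minibatch in dataset:
    logits, _ = gp_layer(minibatch, training=True)
    # ... compute the loss on `logits` and apply optimizer updates here ...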
official/nlp/modeling/layers/spectral_normalization.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Normalization layers.
## References:
[1] Yuichi Yoshida, Takeru Miyato. Spectral Norm Regularization for Improving
the Generalizability of Deep Learning.
_arXiv preprint arXiv:1705.10941_, 2017. https://arxiv.org/abs/1705.10941
[2] Takeru Miyato, Toshiki Kataoka, Masanori Koyama, Yuichi Yoshida.
Spectral normalization for generative adversarial networks.
In _International Conference on Learning Representations_, 2018.
[3] Henry Gouk, Eibe Frank, Bernhard Pfahringer, Michael Cree.
Regularisation of neural networks by enforcing lipschitz continuity.
_arXiv preprint arXiv:1804.04368_, 2018. https://arxiv.org/abs/1804.04368
"""
import numpy as np
import tensorflow as tf


class SpectralNormalization(tf.keras.layers.Wrapper):
"""Implements spectral normalization for Dense layer."""
def
__init__
(
self
,
layer
,
iteration
=
1
,
norm_multiplier
=
0.95
,
training
=
True
,
aggregation
=
tf
.
VariableAggregation
.
MEAN
,
inhere_layer_name
=
False
,
**
kwargs
):
"""Initializer.
Args:
layer: (tf.keras.layers.Layer) A TF Keras layer to apply normalization to.
iteration: (int) The number of power iteration to perform to estimate
weight matrix's singular value.
norm_multiplier: (float) Multiplicative constant to threshold the
normalization. Usually under normalization, the singular value will
converge to this value.
training: (bool) Whether to perform power iteration to update the singular
value estimate.
aggregation: (tf.VariableAggregation) Indicates how a distributed variable
will be aggregated. Accepted values are constants defined in the class
tf.VariableAggregation.
inhere_layer_name: (bool) Whether to inhere the name of the input layer.
**kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
"""
self
.
iteration
=
iteration
self
.
do_power_iteration
=
training
self
.
aggregation
=
aggregation
self
.
norm_multiplier
=
norm_multiplier
# Set layer name.
wrapper_name
=
kwargs
.
pop
(
'name'
,
None
)
if
inhere_layer_name
:
wrapper_name
=
layer
.
name
if
not
isinstance
(
layer
,
tf
.
keras
.
layers
.
Layer
):
raise
ValueError
(
'`layer` must be a `tf.keras.layer.Layer`. '
'Observed `{}`'
.
format
(
layer
))
super
(
SpectralNormalization
,
self
).
__init__
(
layer
,
name
=
wrapper_name
,
**
kwargs
)
  def build(self, input_shape):
    super(SpectralNormalization, self).build(input_shape)
    self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
    self._dtype = self.layer.kernel.dtype

    self.w = self.layer.kernel
    self.w_shape = self.w.shape.as_list()
    self.uv_initializer = tf.initializers.random_normal()

    self.v = self.add_weight(
        shape=(1, np.prod(self.w_shape[:-1])),
        initializer=self.uv_initializer,
        trainable=False,
        name='v',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.u = self.add_weight(
        shape=(1, self.w_shape[-1]),
        initializer=self.uv_initializer,
        trainable=False,
        name='u',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.update_weights()

  def call(self, inputs, *, training=None):
    training = self.do_power_iteration if training is None else training
    u_update_op, v_update_op, w_update_op = self.update_weights(
        training=training)
    output = self.layer(inputs)
    w_restore_op = self.restore_weights()

    # Register update ops.
    self.add_update(u_update_op)
    self.add_update(v_update_op)
    self.add_update(w_update_op)
    self.add_update(w_restore_op)
    return output

  def update_weights(self, *, training=True):
    w_reshaped = tf.reshape(self.w, [-1, self.w_shape[-1]])
    u_hat = self.u
    v_hat = self.v

    if training:
      for _ in range(self.iteration):
        v_hat = tf.nn.l2_normalize(
            tf.matmul(u_hat, tf.transpose(w_reshaped)))
        u_hat = tf.nn.l2_normalize(tf.matmul(v_hat, w_reshaped))

    sigma = tf.matmul(tf.matmul(v_hat, w_reshaped), tf.transpose(u_hat))
    # Convert sigma from a 1x1 matrix to a scalar.
    sigma = tf.reshape(sigma, [])

    u_update_op = self.u.assign(u_hat)
    v_update_op = self.v.assign(v_hat)

    # Bound the spectral norm to be not larger than self.norm_multiplier.
    w_norm = tf.cond(
        (self.norm_multiplier / sigma) < 1,
        lambda: (self.norm_multiplier / sigma) * self.w,
        lambda: self.w)

    w_update_op = self.layer.kernel.assign(w_norm)
    return u_update_op, v_update_op, w_update_op

  def restore_weights(self):
    """Restores layer weights to maintain gradient update (see Alg 1 of [1])."""
    return self.layer.kernel.assign(self.w)
class SpectralNormalizationConv2D(tf.keras.layers.Wrapper):
  """Implements spectral normalization for a Conv2D layer, based on [3]."""

  def __init__(self,
               layer,
               iteration=1,
               norm_multiplier=0.95,
               training=True,
               aggregation=tf.VariableAggregation.MEAN,
               legacy_mode=False,
               **kwargs):
    """Initializer.

    Args:
      layer: (tf.keras.layers.Layer) A TF Keras layer to apply normalization
        to.
      iteration: (int) The number of power iterations to perform to estimate
        the weight matrix's singular value.
      norm_multiplier: (float) Multiplicative constant to threshold the
        normalization. Under normalization, the singular value will usually
        converge to this value.
      training: (bool) Whether to perform power iteration to update the
        singular value estimate.
      aggregation: (tf.VariableAggregation) Indicates how a distributed
        variable will be aggregated. Accepted values are constants defined in
        the class tf.VariableAggregation.
      legacy_mode: (bool) Whether to use the legacy implementation, where the
        dimension of the u and v vectors is set to the batch size. It should
        not be enabled unless needed for backward-compatibility reasons.
      **kwargs: (dict) Other keyword arguments for the layers.Wrapper class.
    """
    self.iteration = iteration
    self.do_power_iteration = training
    self.aggregation = aggregation
    self.norm_multiplier = norm_multiplier
    self.legacy_mode = legacy_mode

    # Set layer attributes.
    layer._name += '_spec_norm'

    if not isinstance(layer, tf.keras.layers.Conv2D):
      raise ValueError(
          'layer must be a `tf.keras.layer.Conv2D` instance. You passed: '
          '{input}'.format(input=layer))
    super(SpectralNormalizationConv2D, self).__init__(layer, **kwargs)
  def build(self, input_shape):
    self.layer.build(input_shape)
    self.layer.kernel._aggregation = self.aggregation  # pylint: disable=protected-access
    self._dtype = self.layer.kernel.dtype

    # Shape (kernel_size_1, kernel_size_2, in_channel, out_channel).
    self.w = self.layer.kernel
    self.w_shape = self.w.shape.as_list()
    self.strides = self.layer.strides

    # Set the dimensions of the u and v vectors.
    batch_size = input_shape[0]
    uv_dim = batch_size if self.legacy_mode else 1

    # Resolve shapes.
    in_height = input_shape[1]
    in_width = input_shape[2]
    in_channel = self.w_shape[2]

    out_height = in_height // self.strides[0]
    out_width = in_width // self.strides[1]
    out_channel = self.w_shape[3]

    self.in_shape = (uv_dim, in_height, in_width, in_channel)
    self.out_shape = (uv_dim, out_height, out_width, out_channel)
    self.uv_initializer = tf.initializers.random_normal()

    self.v = self.add_weight(
        shape=self.in_shape,
        initializer=self.uv_initializer,
        trainable=False,
        name='v',
        dtype=self.dtype,
        aggregation=self.aggregation)

    self.u = self.add_weight(
        shape=self.out_shape,
        initializer=self.uv_initializer,
        trainable=False,
        name='u',
        dtype=self.dtype,
        aggregation=self.aggregation)

    super(SpectralNormalizationConv2D, self).build()

  def call(self, inputs):
    u_update_op, v_update_op, w_update_op = self.update_weights()
    output = self.layer(inputs)
    w_restore_op = self.restore_weights()

    # Register update ops.
    self.add_update(u_update_op)
    self.add_update(v_update_op)
    self.add_update(w_update_op)
    self.add_update(w_restore_op)

    return output

  def update_weights(self):
    """Computes power iteration for convolutional filters based on [3]."""
    # Initialize u, v vectors.
    u_hat = self.u
    v_hat = self.v

    if self.do_power_iteration:
      for _ in range(self.iteration):
        # Updates v.
        v_ = tf.nn.conv2d_transpose(
            u_hat,
            self.w,
            output_shape=self.in_shape,
            strides=self.strides,
            padding='SAME')
        v_hat = tf.nn.l2_normalize(tf.reshape(v_, [1, -1]))
        v_hat = tf.reshape(v_hat, v_.shape)

        # Updates u.
        u_ = tf.nn.conv2d(
            v_hat, self.w, strides=self.strides, padding='SAME')
        u_hat = tf.nn.l2_normalize(tf.reshape(u_, [1, -1]))
        u_hat = tf.reshape(u_hat, u_.shape)

    v_w_hat = tf.nn.conv2d(
        v_hat, self.w, strides=self.strides, padding='SAME')

    sigma = tf.matmul(
        tf.reshape(v_w_hat, [1, -1]), tf.reshape(u_hat, [-1, 1]))
    # Convert sigma from a 1x1 matrix to a scalar.
    sigma = tf.reshape(sigma, [])

    u_update_op = self.u.assign(u_hat)
    v_update_op = self.v.assign(v_hat)

    # Bound the spectral norm to be not larger than self.norm_multiplier.
    w_norm = tf.cond(
        (self.norm_multiplier / sigma) < 1,
        lambda: (self.norm_multiplier / sigma) * self.w,
        lambda: self.w)

    w_update_op = self.layer.kernel.assign(w_norm)

    return u_update_op, v_update_op, w_update_op

  def restore_weights(self):
    """Restores layer weights to maintain gradient update (see Alg 1 of [1])."""
    return self.layer.kernel.assign(self.w)
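
Usage note (illustrative, not part of this commit): a minimal sketch of wrapping a Dense layer, where the layer sizes and input are assumptions.

import tensorflow as tf

from official.nlp.modeling.layers import spectral_normalization

# Wrap a Dense layer so its kernel's spectral norm is thresholded at
# norm_multiplier after each call.
dense = tf.keras.layers.Dense(64)
sn_dense = spectral_normalization.SpectralNormalization(
    dense, iteration=100, norm_multiplier=0.95)

x = tf.random.normal((4, 32))
y = sn_dense(x)  # Builds the wrapper, runs power iteration, rescales kernel.

# With enough power iterations, the kernel's largest singular value should
# sit near (and not above) norm_multiplier, making the layer approximately
# 0.95-Lipschitz.
sigma = tf.reduce_max(tf.linalg.svd(sn_dense.layer.kernel, compute_uv=False))
print(float(sigma))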
official/nlp/modeling/layers/spectral_normalization_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for normalization layers.
## References:
[1] Hanie Sedghi, Vineet Gupta, Philip M. Long.
The Singular Values of Convolutional Layers.
In _International Conference on Learning Representations_, 2019.
"""
from absl.testing import parameterized
import numpy as np
import tensorflow as tf

from official.nlp.modeling.layers import spectral_normalization

DenseLayer = tf.keras.layers.Dense(10)
Conv2DLayer = tf.keras.layers.Conv2D(
    filters=64, kernel_size=3, padding='valid')


def _compute_spectral_norm(weight):
  if weight.ndim > 2:
    # Computes the spectral norm of a Conv2D kernel via the FFT transform as
    # in [1].
    weight = np.fft.fft2(weight, weight.shape[1:3], axes=[0, 1])

  return np.max(np.linalg.svd(weight, compute_uv=False))


class NormalizationTest(tf.test.TestCase, parameterized.TestCase):

  def setUp(self):
    super(NormalizationTest, self).setUp()
    self.num_iterations = 1000
    self.norm_multiplier = 0.95
  @parameterized.named_parameters(
      ('Dense', (None, 10), DenseLayer,
       spectral_normalization.SpectralNormalization),
      ('Conv2D', (None, 32, 32, 3), Conv2DLayer,
       spectral_normalization.SpectralNormalizationConv2D))
  def test_spec_norm_magnitude(self, input_shape, layer, norm_wrapper):
    """Tests that the weight's spectral norm converges to norm_multiplier."""
    layer.build(input_shape)
    sn_layer = norm_wrapper(
        layer,
        iteration=self.num_iterations,
        norm_multiplier=self.norm_multiplier)

    # Perform normalization.
    sn_layer.build(input_shape)
    sn_layer.update_weights()
    normalized_kernel = sn_layer.layer.kernel.numpy()

    spectral_norm_computed = _compute_spectral_norm(normalized_kernel)
    spectral_norm_expected = self.norm_multiplier
    self.assertAllClose(
        spectral_norm_computed, spectral_norm_expected, atol=5e-2)

    # Test that the normalized layer is K-Lipschitz. In particular, if the
    # layer is a function f, then ||f(x1) - f(x2)||_2 <= K * ||x1 - x2||_2,
    # where K is the norm multiplier.
    new_input_shape = (16,) + input_shape[1:]
    new_input = tf.random.uniform(new_input_shape)
    delta_vec = tf.random.uniform(new_input_shape)

    output1 = sn_layer(new_input)
    output2 = sn_layer(new_input + delta_vec)

    delta_input = tf.norm(tf.reshape(delta_vec, (-1,))).numpy()
    delta_output = tf.norm(tf.reshape(output2 - output1, (-1,))).numpy()

    self.assertLessEqual(delta_output, self.norm_multiplier * delta_input)
if __name__ == '__main__':
  tf.test.main()
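
Reference note (a standalone sketch, not from this commit): the power iteration that `update_weights` performs on the reshaped kernel can be reproduced in a few lines of NumPy; all names below are illustrative.

import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=(32, 16))  # Stands in for the reshaped (in, out) kernel.
u = rng.normal(size=(1, 16))   # Right singular-vector estimate, like self.u.

for _ in range(200):
  v = u @ w.T
  v /= np.linalg.norm(v)       # Left singular-vector estimate, like self.v.
  u = v @ w
  u /= np.linalg.norm(u)

sigma = (v @ w @ u.T).item()   # Estimate of the largest singular value.
assert np.isclose(sigma, np.linalg.svd(w, compute_uv=False)[0], rtol=1e-2)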