Commit 77a275b1, authored Sep 12, 2021 by Hongkun Yu, committed by A. Unique TensorFlower on Sep 12, 2021

Internal change

PiperOrigin-RevId: 396239953
Parent: f8418c2d
Showing 7 changed files with 602 additions and 0 deletions:

- official/nlp/projects/example/README.md (+116, -0)
- official/nlp/projects/example/classification_data_loader.py (+81, -0)
- official/nlp/projects/example/classification_example.py (+187, -0)
- official/nlp/projects/example/classification_example_test.py (+69, -0)
- official/nlp/projects/example/experiments/classification_ft_cola.yaml (+41, -0)
- official/nlp/projects/example/experiments/local_example.yaml (+40, -0)
- official/nlp/projects/example/train.py (+68, -0)
official/nlp/projects/example/README.md (new file, mode 100644)
# NLP example project

This is a tutorial for setting up your project using the TF-NLP library. Here we
focus on the scaffolding of the project and pay little attention to modeling
aspects. Below we use classification as an example.
## Set up your codebase

First you need to define the
[Task](https://github.com/tensorflow/models/blob/master/official/core/base_task.py)
by inheriting from it. A Task is an abstraction of a machine learning task; here
we focus on two things: the inputs and the optimization target.

NOTE: We use BertClassifier as the base model. You can browse other models
[here](https://github.com/tensorflow/models/blob/master/official/nlp/modeling/models).
#### Step 1: build\_inputs

Here we use [CoLA](https://nyu-mll.github.io/CoLA/), a binary classification
task, as an example.

TODO(saberkun): Add demo data instructions.

There are four fields we care about in the tf.Example: `input_ids`,
`input_mask`, `segment_ids`, and `label_ids`. We start with a simple data
loader by inheriting from the
[DataLoader](https://github.com/tensorflow/models/blob/master/official/nlp/data/data_loader.py)
interface.
```python
class ClassificationDataLoader(data_loader.DataLoader):
  ...

  def _parse(self, record: Mapping[str, tf.Tensor]):
    """Parses raw tensors into a dict of tensors to be consumed by the model."""
    x = {
        'input_word_ids': record['input_ids'],
        'input_mask': record['input_mask'],
        'input_type_ids': record['segment_ids']
    }
    y = record['label_ids']
    return (x, y)

  ...
```
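Until the demo data instructions land, you can generate placeholder records
yourself. Below is a minimal sketch; the output path and the trivial
all-zeros/all-ones feature values are illustrative assumptions, not project
code:

```python
import tensorflow as tf

def _int64_feature(values):
  return tf.train.Feature(int64_list=tf.train.Int64List(value=values))

# Writes two fake examples with the four fields the loader expects.
seq_length = 128
with tf.io.TFRecordWriter('/tmp/cola_demo.tf_record') as writer:
  for label in (0, 1):
    features = {
        'input_ids': _int64_feature([0] * seq_length),
        'input_mask': _int64_feature([1] * seq_length),
        'segment_ids': _int64_feature([0] * seq_length),
        'label_ids': _int64_feature([label]),
    }
    example = tf.train.Example(features=tf.train.Features(feature=features))
    writer.write(example.SerializeToString())
```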
Overall, the loader translates the tf.Example into the appropriate format for
the model to consume. Then, in `Task.build_inputs`, hook up the dataset like
this:
```python
def build_inputs(self):
  ...
  loader = classification_data_loader.ClassificationDataLoader(params)
  return loader.load(input_context)
```
#### Step 2: build\_losses

We use the standard cross-entropy loss and make sure `build_losses()` returns a
float scalar Tensor.
```python
def build_losses(self, labels, model_outputs, aux_losses=None):
  loss = tf.keras.losses.sparse_categorical_crossentropy(
      labels, tf.cast(model_outputs, tf.float32), from_logits=True)
  ...
```
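If the shapes are ever in doubt, a quick standalone check with made-up values
confirms what the loss expects: integer labels of shape `[batch]` and logits of
shape `[batch, num_classes]` (the full task code then averages the per-example
values with `tf_utils.safe_mean`):

```python
import tensorflow as tf

# Made-up values purely to illustrate the expected shapes.
labels = tf.constant([0, 1])                     # [batch]
logits = tf.constant([[2.0, -1.0], [0.5, 0.5]])  # [batch, num_classes]
loss = tf.keras.losses.sparse_categorical_crossentropy(
    labels, logits, from_logits=True)
print(loss.shape)  # (2,): one value per example, before averaging.
```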
#### Try the workflow locally.
We use a small BERT model for local trial and error. Below is the command:
```shell
# Assume you are under official/nlp/projects.
python3 example/train.py \
  --experiment=example_bert_classification_example \
  --config_file=example/experiments/local_example.yaml \
  --mode=train \
  --model_dir=/tmp/example_project_test/
```
The train binary parses the config file for the experiment. Usually you only
need to change the task import logic:
```python
task_config = classification_example.ClassificationExampleConfig()
task = classification_example.ClassificationExampleTask(task_config)
```
TIP: You can also check the
[unit test](https://github.com/tensorflow/models/blob/master/official/nlp/projects/example/classification_example_test.py)
for a better understanding.
### Finetune

TF-NLP makes it easy to start from a
[pretrained checkpoint](https://github.com/tensorflow/models/blob/master/official/nlp/docs/pretrained_models.md);
try the command below. This is done by configuring `task.init_checkpoint` in
the YAML config; see the
[base_task.initialize](https://github.com/tensorflow/models/blob/master/official/core/base_task.py)
method for more details. We use a GCP TPU to demonstrate this.
```shell
EXP_NAME=bert_base_cola
EXP_TYPE=example_bert_classification_example
CONFIG_FILE=example/experiments/classification_ft_cola.yaml
TPU_NAME=experiment01
MODEL_DIR='your GCS bucket folder'
python3 example/train.py \
  --experiment=$EXP_TYPE \
  --mode=train_and_eval \
  --tpu=$TPU_NAME \
  --model_dir=${MODEL_DIR} \
  --config_file=${CONFIG_FILE}
```
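For intuition, restoring from `task.init_checkpoint` boils down to something
like the following simplified sketch; the real logic lives in
`base_task.Task.initialize` (which also handles partial restores and different
checkpoint layouts), and the checkpoint path here is a placeholder:

```python
import tensorflow as tf

# Simplified sketch only; see base_task.Task.initialize for the real logic.
model = task.build_model()
ckpt = tf.train.Checkpoint(model=model)
ckpt.read('path/to/pretrained/bert_checkpoint').expect_partial()
```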
official/nlp/projects/example/classification_data_loader.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Loads dataset for classification tasks."""
from
typing
import
Dict
,
Mapping
,
Optional
,
Tuple
import
dataclasses
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.core
import
input_reader
from
official.nlp.data
import
data_loader
@
dataclasses
.
dataclass
class
ClassificationExampleDataConfig
(
cfg
.
DataConfig
):
"""Data config for token classification task."""
seq_length
:
int
=
128
class
ClassificationDataLoader
(
data_loader
.
DataLoader
):
"""A class to load dataset for sentence prediction (classification) task."""
def
__init__
(
self
,
params
):
self
.
_params
=
params
self
.
_seq_length
=
params
.
seq_length
def
_decode
(
self
,
record
:
tf
.
Tensor
)
->
Dict
[
str
,
tf
.
Tensor
]:
"""Decodes a serialized tf.Example."""
name_to_features
=
{
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_mask'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'label_ids'
:
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
),
}
example
=
tf
.
io
.
parse_single_example
(
record
,
name_to_features
)
# tf.Example only supports tf.int64, but the TPU only supports tf.int32.
# So cast all int64 to int32.
for
name
in
example
:
t
=
example
[
name
]
if
t
.
dtype
==
tf
.
int64
:
t
=
tf
.
cast
(
t
,
tf
.
int32
)
example
[
name
]
=
t
return
example
def
_parse
(
self
,
record
:
Mapping
[
str
,
tf
.
Tensor
])
->
Tuple
[
Dict
[
str
,
tf
.
Tensor
],
tf
.
Tensor
]:
"""Parses raw tensors into a dict of tensors to be consumed by the model."""
x
=
{
'input_word_ids'
:
record
[
'input_ids'
],
'input_mask'
:
record
[
'input_mask'
],
'input_type_ids'
:
record
[
'segment_ids'
]
}
y
=
record
[
'label_ids'
]
return
(
x
,
y
)
def
load
(
self
,
input_context
:
Optional
[
tf
.
distribute
.
InputContext
]
=
None
)
->
tf
.
data
.
Dataset
:
"""Returns a tf.dataset.Dataset."""
reader
=
input_reader
.
InputReader
(
params
=
self
.
_params
,
decoder_fn
=
self
.
_decode
,
parser_fn
=
self
.
_parse
)
return
reader
.
read
(
input_context
)
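As a quick smoke test, the loader can also be driven outside of a Task. Below
is a minimal sketch, assuming a TFRecord file with the four expected fields
(for instance, the placeholder records from the README); the input path and
batch size are assumptions for illustration:

```python
# Sketch: exercise the loader directly on placeholder records.
params = ClassificationExampleDataConfig(
    input_path='/tmp/cola_demo.tf_record',
    seq_length=128,
    is_training=False,
    global_batch_size=2)
dataset = ClassificationDataLoader(params).load()
x, y = next(iter(dataset))
print(x['input_word_ids'].shape, y.shape)  # Expected: (2, 128) (2,)
```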
official/nlp/projects/example/classification_example.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Classifcation Task Showcase."""
import
dataclasses
from
typing
import
List
,
Mapping
,
Text
from
seqeval
import
metrics
as
seqeval_metrics
import
tensorflow
as
tf
from
official.core
import
base_task
from
official.core
import
config_definitions
as
cfg
from
official.core
import
exp_factory
from
official.modeling
import
optimization
from
official.modeling
import
tf_utils
from
official.modeling.hyperparams
import
base_config
from
official.nlp.configs
import
encoders
from
official.nlp.modeling
import
models
from
official.nlp.projects.example
import
classification_data_loader
from
official.nlp.tasks
import
utils
@
dataclasses
.
dataclass
class
ModelConfig
(
base_config
.
Config
):
"""A base span labeler configuration."""
encoder
:
encoders
.
EncoderConfig
=
encoders
.
EncoderConfig
()
head_dropout
:
float
=
0.1
head_initializer_range
:
float
=
0.02
@
dataclasses
.
dataclass
class
ClassificationExampleConfig
(
cfg
.
TaskConfig
):
"""The model config."""
# At most one of `init_checkpoint` and `hub_module_url` can be specified.
init_checkpoint
:
str
=
''
hub_module_url
:
str
=
''
model
:
ModelConfig
=
ModelConfig
()
num_classes
=
2
class_names
=
[
'A'
,
'B'
]
train_data
:
cfg
.
DataConfig
=
classification_data_loader
.
ClassificationExampleDataConfig
(
)
validation_data
:
cfg
.
DataConfig
=
classification_data_loader
.
ClassificationExampleDataConfig
(
)
class
ClassificationExampleTask
(
base_task
.
Task
):
"""Task object for classification."""
def
build_model
(
self
)
->
tf
.
keras
.
Model
:
if
self
.
task_config
.
hub_module_url
and
self
.
task_config
.
init_checkpoint
:
raise
ValueError
(
'At most one of `hub_module_url` and '
'`init_checkpoint` can be specified.'
)
if
self
.
task_config
.
hub_module_url
:
encoder_network
=
utils
.
get_encoder_from_hub
(
self
.
task_config
.
hub_module_url
)
else
:
encoder_network
=
encoders
.
build_encoder
(
self
.
task_config
.
model
.
encoder
)
return
models
.
BertClassifier
(
network
=
encoder_network
,
num_classes
=
len
(
self
.
task_config
.
class_names
),
initializer
=
tf
.
keras
.
initializers
.
TruncatedNormal
(
stddev
=
self
.
task_config
.
model
.
head_initializer_range
),
dropout_rate
=
self
.
task_config
.
model
.
head_dropout
)
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
loss
=
tf
.
keras
.
losses
.
sparse_categorical_crossentropy
(
labels
,
tf
.
cast
(
model_outputs
,
tf
.
float32
),
from_logits
=
True
)
return
tf_utils
.
safe_mean
(
loss
)
def
build_inputs
(
self
,
params
:
cfg
.
DataConfig
,
input_context
=
None
)
->
tf
.
data
.
Dataset
:
"""Returns tf.data.Dataset for sentence_prediction task."""
loader
=
classification_data_loader
.
ClassificationDataLoader
(
params
)
return
loader
.
load
(
input_context
)
def
inference_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
)
->
Mapping
[
str
,
tf
.
Tensor
]:
"""Performs the forward step."""
logits
=
model
(
inputs
,
training
=
False
)
return
{
'logits'
:
logits
,
'predict_ids'
:
tf
.
argmax
(
logits
,
axis
=-
1
,
output_type
=
tf
.
int32
)
}
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
)
->
Mapping
[
str
,
tf
.
Tensor
]:
"""Validatation step.
Args:
inputs: a dictionary of input tensors.
model: the keras.Model.
metrics: a nested structure of metrics objects.
Returns:
A dictionary of logs.
"""
features
,
labels
=
inputs
outputs
=
self
.
inference_step
(
features
,
model
)
loss
=
self
.
build_losses
(
labels
=
labels
,
model_outputs
=
outputs
[
'logits'
])
# Negative label ids are padding labels which should be ignored.
real_label_index
=
tf
.
where
(
tf
.
greater_equal
(
labels
,
0
))
predict_ids
=
tf
.
gather_nd
(
outputs
[
'predict_ids'
],
real_label_index
)
label_ids
=
tf
.
gather_nd
(
labels
,
real_label_index
)
return
{
self
.
loss
:
loss
,
'predict_ids'
:
predict_ids
,
'label_ids'
:
label_ids
,
}
def
aggregate_logs
(
self
,
state
=
None
,
step_outputs
=
None
)
->
Mapping
[
Text
,
List
[
List
[
Text
]]]:
"""Aggregates over logs returned from a validation step."""
if
state
is
None
:
state
=
{
'predict_class'
:
[],
'label_class'
:
[]}
def
id_to_class_name
(
batched_ids
):
class_names
=
[]
for
per_example_ids
in
batched_ids
:
class_names
.
append
([])
for
per_token_id
in
per_example_ids
.
numpy
().
tolist
():
class_names
[
-
1
].
append
(
self
.
task_config
.
class_names
[
per_token_id
])
return
class_names
# Convert id to class names, because `seqeval_metrics` relies on the class
# name to decide IOB tags.
state
[
'predict_class'
].
extend
(
id_to_class_name
(
step_outputs
[
'predict_ids'
]))
state
[
'label_class'
].
extend
(
id_to_class_name
(
step_outputs
[
'label_ids'
]))
return
state
def
reduce_aggregated_logs
(
self
,
aggregated_logs
,
global_step
=
None
)
->
Mapping
[
Text
,
float
]:
"""Reduces aggregated logs over validation steps."""
label_class
=
aggregated_logs
[
'label_class'
]
predict_class
=
aggregated_logs
[
'predict_class'
]
return
{
'f1'
:
seqeval_metrics
.
f1_score
(
label_class
,
predict_class
),
'precision'
:
seqeval_metrics
.
precision_score
(
label_class
,
predict_class
),
'recall'
:
seqeval_metrics
.
recall_score
(
label_class
,
predict_class
),
'accuracy'
:
seqeval_metrics
.
accuracy_score
(
label_class
,
predict_class
),
}
@
exp_factory
.
register_config_factory
(
'example_bert_classification_example'
)
def
bert_classification_example
()
->
cfg
.
ExperimentConfig
:
"""Return a minimum experiment config for Bert token classification."""
return
cfg
.
ExperimentConfig
(
task
=
ClassificationExampleConfig
(),
trainer
=
cfg
.
TrainerConfig
(
optimizer_config
=
optimization
.
OptimizationConfig
({
'optimizer'
:
{
'type'
:
'adamw'
,
},
'learning_rate'
:
{
'type'
:
'polynomial'
,
},
'warmup'
:
{
'type'
:
'polynomial'
}
})),
restrictions
=
[
'task.train_data.is_training != None'
,
'task.validation_data.is_training != None'
])
official/nlp/projects/example/classification_example_test.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for nlp.projects.example.classification_example."""
import
tensorflow
as
tf
from
official.core
import
config_definitions
as
cfg
from
official.nlp.configs
import
encoders
from
official.nlp.projects.example
import
classification_data_loader
from
official.nlp.projects.example
import
classification_example
class
ClassificationExampleTest
(
tf
.
test
.
TestCase
):
def
get_model_config
(
self
):
return
classification_example
.
ModelConfig
(
encoder
=
encoders
.
EncoderConfig
(
bert
=
encoders
.
BertEncoderConfig
(
vocab_size
=
30522
,
num_layers
=
2
)))
def
get_dummy_dataset
(
self
,
params
:
cfg
.
DataConfig
):
def
dummy_data
(
_
):
dummy_ids
=
tf
.
zeros
((
1
,
params
.
seq_length
),
dtype
=
tf
.
int32
)
x
=
dict
(
input_word_ids
=
dummy_ids
,
input_mask
=
dummy_ids
,
input_type_ids
=
dummy_ids
)
y
=
tf
.
zeros
((
1
,
1
),
dtype
=
tf
.
int32
)
return
(
x
,
y
)
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
dataset
.
repeat
()
dataset
=
dataset
.
map
(
dummy_data
,
num_parallel_calls
=
tf
.
data
.
experimental
.
AUTOTUNE
)
return
dataset
def
test_task_with_dummy_data
(
self
):
train_data_config
=
(
classification_data_loader
.
ClassificationExampleDataConfig
(
input_path
=
'dummy'
,
seq_length
=
128
,
global_batch_size
=
1
))
task_config
=
classification_example
.
ClassificationExampleConfig
(
model
=
self
.
get_model_config
(),)
task
=
classification_example
.
ClassificationExampleTask
(
task_config
)
task
.
build_inputs
=
self
.
get_dummy_dataset
model
=
task
.
build_model
()
metrics
=
task
.
build_metrics
()
dataset
=
task
.
build_inputs
(
train_data_config
)
iterator
=
iter
(
dataset
)
optimizer
=
tf
.
keras
.
optimizers
.
SGD
(
lr
=
0.1
)
task
.
initialize
(
model
)
task
.
train_step
(
next
(
iterator
),
model
,
optimizer
,
metrics
=
metrics
)
if
__name__
==
'__main__'
:
tf
.
test
.
main
()
official/nlp/projects/example/experiments/classification_ft_cola.yaml (new file, mode 100644)
task:
  model:
    encoder:
      type: bert
      bert:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        hidden_activation: gelu
        hidden_size: 768
        initializer_range: 0.02
        intermediate_size: 3072
        max_position_embeddings: 512
        num_attention_heads: 12
        num_layers: 12
        type_vocab_size: 2
        vocab_size: 30522
  init_checkpoint: 'BERT checkpoint'
  train_data:
    input_path: 'YourData/COLA_train.tf_record'
    is_training: true
    global_batch_size: 32
  validation_data:
    input_path: 'YourData/COLA_eval.tf_record'
    is_training: false
    global_batch_size: 32
trainer:
  checkpoint_interval: 5000
  max_to_keep: 5
  steps_per_loop: 100
  summary_interval: 100
  train_steps: 10000
  validation_interval: 100
  validation_steps: -1
  optimizer_config:
    learning_rate:
      polynomial:
        initial_learning_rate: 0.00002
        decay_steps: 10000
    warmup:
      polynomial:
        warmup_steps: 100
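The YAML keys mirror the config dataclasses one-to-one, so the same settings
can be applied programmatically. Below is a small sketch using
`Config.override` from `official.modeling.hyperparams`; the values are the ones
from this file (the checkpoint path and input paths remain placeholders):

```python
from official.core import exp_factory
from official.nlp.projects.example import classification_example  # pylint: disable=unused-import

# Sketch: apply a subset of this YAML's settings via Config.override.
config = exp_factory.get_exp_config('example_bert_classification_example')
config.override({
    'task': {
        'init_checkpoint': 'BERT checkpoint',
        'train_data': {
            'input_path': 'YourData/COLA_train.tf_record',
            'is_training': True,
            'global_batch_size': 32,
        },
    },
}, is_strict=True)
```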
official/nlp/projects/example/experiments/local_example.yaml (new file, mode 100644)
task:
  model:
    encoder:
      type: bert
      bert:
        attention_dropout_rate: 0.1
        dropout_rate: 0.1
        hidden_activation: gelu
        hidden_size: 288
        initializer_range: 0.02
        intermediate_size: 256
        max_position_embeddings: 512
        num_attention_heads: 6
        num_layers: 2
        type_vocab_size: 4
        vocab_size: 114507
  train_data:
    input_path: 'YourData/COLA_train.tf_record'
    is_training: true
    global_batch_size: 32
  validation_data:
    input_path: 'YourData/COLA_eval.tf_record'
    is_training: false
    global_batch_size: 32
trainer:
  checkpoint_interval: 500
  max_to_keep: 5
  steps_per_loop: 100
  summary_interval: 100
  train_steps: 500
  validation_interval: 100
  validation_steps: -1
  optimizer_config:
    learning_rate:
      polynomial:
        initial_learning_rate: 0.001
        decay_steps: 740000
    warmup:
      polynomial:
        warmup_steps: 100
official/nlp/projects/example/train.py (new file, mode 100644)
# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""A customized training library for the specific task."""
from
absl
import
app
from
absl
import
flags
import
gin
from
official.common
import
distribute_utils
from
official.common
import
flags
as
tfm_flags
from
official.core
import
train_lib
from
official.core
import
train_utils
from
official.modeling
import
performance
from
official.nlp.projects.example
import
classification_example
FLAGS
=
flags
.
FLAGS
def
main
(
_
):
gin
.
parse_config_files_and_bindings
(
FLAGS
.
gin_file
,
FLAGS
.
gin_params
)
params
=
train_utils
.
parse_configuration
(
FLAGS
)
model_dir
=
FLAGS
.
model_dir
if
'train'
in
FLAGS
.
mode
:
# Pure eval modes do not output yaml files. Otherwise continuous eval job
# may race against the train job for writing the same file.
train_utils
.
serialize_config
(
params
,
model_dir
)
# Sets mixed_precision policy. Using 'mixed_float16' or 'mixed_bfloat16'
# can have significant impact on model speeds by utilizing float16 in case of
# GPUs, and bfloat16 in the case of TPUs. loss_scale takes effect only when
# dtype is float16
if
params
.
runtime
.
mixed_precision_dtype
:
performance
.
set_mixed_precision_policy
(
params
.
runtime
.
mixed_precision_dtype
)
distribution_strategy
=
distribute_utils
.
get_distribution_strategy
(
distribution_strategy
=
params
.
runtime
.
distribution_strategy
,
all_reduce_alg
=
params
.
runtime
.
all_reduce_alg
,
num_gpus
=
params
.
runtime
.
num_gpus
,
tpu_address
=
params
.
runtime
.
tpu
,
**
params
.
runtime
.
model_parallelism
())
with
distribution_strategy
.
scope
():
task
=
classification_example
.
ClassificationExampleTask
(
params
.
task
)
train_lib
.
run_experiment
(
distribution_strategy
=
distribution_strategy
,
task
=
task
,
mode
=
FLAGS
.
mode
,
params
=
params
,
model_dir
=
model_dir
)
train_utils
.
save_gin_config
(
FLAGS
.
mode
,
model_dir
)
if
__name__
==
'__main__'
:
tfm_flags
.
define_flags
()
app
.
run
(
main
)