Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4193aa9f
Commit
4193aa9f
authored
Nov 19, 2019
by
alexzubiaga
Committed by
alexzubiaga
Nov 19, 2019
Browse files
add TFXLNetForTokenClassification implementation and unit test
add XLNetForTokenClassification implementation and unit tests
parent
f3386d93
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
232 additions
and
11 deletions
+232
-11
transformers/__init__.py
transformers/__init__.py
+5
-3
transformers/modeling_tf_xlnet.py
transformers/modeling_tf_xlnet.py
+53
-0
transformers/modeling_xlnet.py
transformers/modeling_xlnet.py
+100
-0
transformers/tests/modeling_tf_xlnet_test.py
transformers/tests/modeling_tf_xlnet_test.py
+26
-0
transformers/tests/modeling_xlnet_test.py
transformers/tests/modeling_xlnet_test.py
+48
-8
No files found.
transformers/__init__.py
View file @
4193aa9f
...
...
@@ -83,9 +83,10 @@ if is_torch_available():
CTRLLMHeadModel
,
CTRL_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlnet
import
(
XLNetPreTrainedModel
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForMultipleChoice
,
XLNetForQuestionAnsweringSimple
,
XLNetForQuestionAnswering
,
load_tf_weights_in_xlnet
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
XLNetForSequenceClassification
,
XLNetForTokenClassification
,
XLNetForMultipleChoice
,
XLNetForQuestionAnsweringSimple
,
XLNetForQuestionAnswering
,
load_tf_weights_in_xlnet
,
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
from
.modeling_xlm
import
(
XLMPreTrainedModel
,
XLMModel
,
XLMWithLMHeadModel
,
XLMForSequenceClassification
,
XLMForQuestionAnswering
,
XLMForQuestionAnsweringSimple
,
...
...
@@ -136,6 +137,7 @@ if is_tf_available():
from
.modeling_tf_xlnet
import
(
TFXLNetPreTrainedModel
,
TFXLNetMainLayer
,
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
TFXLNetForTokenClassification
,
TFXLNetForQuestionAnsweringSimple
,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
...
...
transformers/modeling_tf_xlnet.py
View file @
4193aa9f
...
...
@@ -939,6 +939,59 @@ class TFXLNetForSequenceClassification(TFXLNetPreTrainedModel):
return
outputs
# return logits, (mems), (hidden states), (attentions)
@
add_start_docstrings
(
"""XLNet Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """
,
XLNET_START_DOCSTRING
,
XLNET_INPUTS_DOCSTRING
)
class
TFXLNetForTokenClassification
(
TFXLNetPreTrainedModel
):
r
"""
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**scores**: ``tf.Tensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``tf.Tensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``tf.Tensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``tf.Tensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
import tensorflow as tf
from transformers import XLNetTokenizer, TFXLNetForTokenClassification
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = TFXLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
input_ids = tf.constant(tokenizer.encode("Hello, my dog is cute"))[None, :] # Batch size 1
outputs = model(input_ids)
scores = outputs[0]
"""
def
__init__
(
self
,
config
,
*
inputs
,
**
kwargs
):
super
(
TFXLNetForTokenClassification
,
self
).
__init__
(
config
,
*
inputs
,
**
kwargs
)
self
.
num_labels
=
config
.
num_labels
self
.
transformer
=
TFXLNetMainLayer
(
config
,
name
=
'transformer'
)
self
.
classifier
=
tf
.
keras
.
layers
.
Dense
(
config
.
num_labels
,
kernel_initializer
=
get_initializer
(
config
.
initializer_range
),
name
=
'classifier'
)
def
call
(
self
,
inputs
,
**
kwargs
):
transformer_outputs
=
self
.
transformer
(
inputs
,
**
kwargs
)
output
=
transformer_outputs
[
0
]
logits
=
self
.
classifier
(
output
)
outputs
=
(
logits
,)
+
transformer_outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
return
outputs
# return logits, (mems), (hidden states), (attentions)
# @add_start_docstrings("""XLNet Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear layers on top of
# the hidden-states output to compute `span start logits` and `span end logits`). """,
# XLNET_START_DOCSTRING, XLNET_INPUTS_DOCSTRING)
...
...
transformers/modeling_xlnet.py
View file @
4193aa9f
...
...
@@ -1046,6 +1046,106 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
@
add_start_docstrings
(
"""XLNet Model with a token classification head on top (a linear layer on top of
the hidden-states output) e.g. for Named-Entity-Recognition (NER) tasks. """
,
XLNET_START_DOCSTRING
,
XLNET_INPUTS_DOCSTRING
)
class
XLNetForTokenClassification
(
XLNetPreTrainedModel
):
r
"""
Inputs:
**input_ids**: ``torch.LongTensor`` of shape ``(batch_size, num_choices, sequence_length)``:
Indices of input sequence tokens in the vocabulary.
The second dimension of the input (`num_choices`) indicates the number of choices to scores.
**token_type_ids**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size, sequence_length)``:
Segment token indices to indicate first and second portions of the inputs.
Indices are selected in ``[0, 1]``: ``0`` corresponds to a `sentence A` token, ``1``
**attention_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length)``:
Mask to avoid performing attention on padding token indices.
Mask values selected in ``[0, 1]``:
``1`` for tokens that are NOT MASKED, ``0`` for MASKED tokens.
**head_mask**: (`optional`) ``torch.FloatTensor`` of shape ``(num_heads,)`` or ``(num_layers, num_heads)``:
Mask to nullify selected heads of the self-attention modules.
Mask values selected in ``[0, 1]``:
``1`` indicates the head is **not masked**, ``0`` indicates the head is **masked**.
**inputs_embeds**: (`optional`) ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, embedding_dim)``:
Optionally, instead of passing ``input_ids`` you can choose to directly pass an embedded representation.
This is useful if you want more control over how to convert `input_ids` indices into associated vectors
than the model's internal embedding lookup matrix.
**labels**: (`optional`) ``torch.LongTensor`` of shape ``(batch_size,)``:
Labels for computing the multiple choice classification loss.
Indices should be in ``[0, ..., num_choices]`` where `num_choices` is the size of the second dimension
of the input tensors. (see `input_ids` above)
Outputs: `Tuple` comprising various elements depending on the configuration (config) and inputs:
**loss**: (`optional`, returned when ``labels`` is provided) ``torch.FloatTensor`` of shape ``(1,)``:
Classification loss.
**scores**: ``torch.FloatTensor`` of shape ``(batch_size, sequence_length, config.num_labels)``
Classification scores (before SoftMax).
**mems**: (`optional`, returned when ``config.mem_len > 0``)
list of ``torch.FloatTensor`` (one for each layer):
that contains pre-computed hidden-states (key and values in the attention blocks) as computed by the model
if config.mem_len > 0 else tuple of None. Can be used to speed up sequential decoding and attend to longer context.
See details in the docstring of the `mems` input above.
**hidden_states**: (`optional`, returned when ``config.output_hidden_states=True``)
list of ``torch.FloatTensor`` (one for the output of each layer + the output of the embeddings)
of shape ``(batch_size, sequence_length, hidden_size)``:
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
**attentions**: (`optional`, returned when ``config.output_attentions=True``)
list of ``torch.FloatTensor`` (one for each layer) of shape ``(batch_size, num_heads, sequence_length, sequence_length)``:
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention heads.
Examples::
tokenizer = XLNetTokenizer.from_pretrained('xlnet-large-cased')
model = XLNetForSequenceClassification.from_pretrained('xlnet-large-cased')
input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute")).unsqueeze(0) # Batch size 1
labels = torch.tensor([1] * input_ids.size(1)).unsqueeze(0) # Batch size 1
outputs = model(input_ids, labels=labels)
scores = outputs[0]
"""
def
__init__
(
self
,
config
):
super
(
XLNetForTokenClassification
,
self
).
__init__
(
config
)
self
.
num_labels
=
config
.
num_labels
self
.
transformer
=
XLNetModel
(
config
)
self
.
classifier
=
nn
.
Linear
(
config
.
hidden_size
,
config
.
num_labels
)
self
.
init_weights
()
def
forward
(
self
,
input_ids
=
None
,
attention_mask
=
None
,
mems
=
None
,
perm_mask
=
None
,
target_mapping
=
None
,
token_type_ids
=
None
,
input_mask
=
None
,
head_mask
=
None
,
inputs_embeds
=
None
,
labels
=
None
):
outputs
=
self
.
transformer
(
input_ids
,
attention_mask
=
attention_mask
,
mems
=
mems
,
perm_mask
=
perm_mask
,
target_mapping
=
target_mapping
,
token_type_ids
=
token_type_ids
,
input_mask
=
input_mask
,
head_mask
=
head_mask
,
inputs_embeds
=
inputs_embeds
)
sequence_output
=
outputs
[
0
]
logits
=
self
.
classifier
(
sequence_output
)
outputs
=
(
logits
,)
+
outputs
[
1
:]
# Keep mems, hidden states, attentions if there are in it
if
labels
is
not
None
:
loss_fct
=
CrossEntropyLoss
()
# Only keep active parts of the loss
if
attention_mask
is
not
None
:
active_loss
=
attention_mask
.
view
(
-
1
)
==
1
active_logits
=
logits
.
view
(
-
1
,
self
.
num_labels
)[
active_loss
]
active_labels
=
labels
.
view
(
-
1
)[
active_loss
]
loss
=
loss_fct
(
active_logits
,
active_labels
)
else
:
loss
=
loss_fct
(
logits
.
view
(
-
1
,
self
.
num_labels
),
labels
.
view
(
-
1
))
outputs
=
(
loss
,)
+
outputs
return
outputs
# return (loss), logits, (mems), (hidden states), (attentions)
@
add_start_docstrings
(
"""XLNet Model with a multiple choice classification head on top (a linear layer on top of
the pooled output and a softmax) e.g. for RACE/SWAG tasks. """
,
XLNET_START_DOCSTRING
,
XLNET_INPUTS_DOCSTRING
)
...
...
transformers/tests/modeling_tf_xlnet_test.py
View file @
4193aa9f
...
...
@@ -30,6 +30,7 @@ if is_tf_available():
from
transformers.modeling_tf_xlnet
import
(
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
TFXLNetForTokenClassification
,
TFXLNetForQuestionAnsweringSimple
,
TF_XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
)
else
:
...
...
@@ -42,6 +43,7 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
all_model_classes
=
(
TFXLNetModel
,
TFXLNetLMHeadModel
,
TFXLNetForSequenceClassification
,
TFXLNetForTokenClassification
,
TFXLNetForQuestionAnsweringSimple
)
if
is_tf_available
()
else
()
test_pruning
=
False
...
...
@@ -258,6 +260,26 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_for_token_classification
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
config
.
num_labels
=
input_ids_1
.
shape
[
1
]
model
=
TFXLNetForTokenClassification
(
config
)
inputs
=
{
'input_ids'
:
input_ids_1
,
'attention_mask'
:
input_mask
,
# 'token_type_ids': token_type_ids
}
logits
,
mems_1
=
model
(
inputs
)
result
=
{
"mems_1"
:
[
mem
.
numpy
()
for
mem
in
mems_1
],
"logits"
:
logits
.
numpy
(),
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
shape
),
[
self
.
batch_size
,
self
.
seq_length
,
config
.
num_labels
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
shape
)
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
...
...
@@ -289,6 +311,10 @@ class TFXLNetModelTest(TFCommonTestCases.TFCommonModelTester):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_sequence_classif
(
*
config_and_inputs
)
def
test_xlnet_token_classification
(
self
):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_for_token_classification
(
*
config_and_inputs
)
def
test_xlnet_qa
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
...
...
transformers/tests/modeling_xlnet_test.py
View file @
4193aa9f
...
...
@@ -28,7 +28,8 @@ from transformers import is_torch_available
if
is_torch_available
():
import
torch
from
transformers
import
(
XLNetConfig
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
)
from
transformers
import
(
XLNetConfig
,
XLNetModel
,
XLNetLMHeadModel
,
XLNetForSequenceClassification
,
XLNetForTokenClassification
,
XLNetForQuestionAnswering
)
from
transformers.modeling_xlnet
import
XLNET_PRETRAINED_MODEL_ARCHIVE_MAP
else
:
pytestmark
=
pytest
.
mark
.
skip
(
"Require Torch"
)
...
...
@@ -38,7 +39,7 @@ from .configuration_common_test import ConfigTester
class
XLNetModelTest
(
CommonTestCases
.
CommonModelTester
):
all_model_classes
=
(
XLNetModel
,
XLNetLMHeadModel
,
all_model_classes
=
(
XLNetModel
,
XLNetLMHeadModel
,
XLNetForTokenClassification
,
XLNetForSequenceClassification
,
XLNetForQuestionAnswering
)
if
is_torch_available
()
else
()
test_pruning
=
False
...
...
@@ -107,10 +108,12 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
sequence_labels
=
None
lm_labels
=
None
is_impossible_labels
=
None
token_labels
=
None
if
self
.
use_labels
:
lm_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
vocab_size
)
sequence_labels
=
ids_tensor
([
self
.
batch_size
],
self
.
type_sequence_label_size
)
is_impossible_labels
=
ids_tensor
([
self
.
batch_size
],
2
).
float
()
token_labels
=
ids_tensor
([
self
.
batch_size
,
self
.
seq_length
],
self
.
type_vocab_size
)
config
=
XLNetConfig
(
vocab_size_or_config_json_file
=
self
.
vocab_size
,
...
...
@@ -129,14 +132,14 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
num_labels
=
self
.
type_sequence_label_size
)
return
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
)
def
set_seed
(
self
):
random
.
seed
(
self
.
seed
)
torch
.
manual_seed
(
self
.
seed
)
def
create_and_check_xlnet_base_model
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetModel
(
config
)
model
.
eval
()
...
...
@@ -164,7 +167,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_lm_head
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetLMHeadModel
(
config
)
model
.
eval
()
...
...
@@ -204,7 +207,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
[[
self
.
mem_len
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_qa
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetForQuestionAnswering
(
config
)
model
.
eval
()
...
...
@@ -261,8 +264,40 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
create_and_check_xlnet_token_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetForTokenClassification
(
config
)
model
.
eval
()
logits
,
mems_1
=
model
(
input_ids_1
)
loss
,
logits
,
mems_1
=
model
(
input_ids_1
,
labels
=
token_labels
)
result
=
{
"loss"
:
loss
,
"mems_1"
:
mems_1
,
"logits"
:
logits
,
}
self
.
parent
.
assertListEqual
(
list
(
result
[
"loss"
].
size
()),
[])
self
.
parent
.
assertListEqual
(
list
(
result
[
"logits"
].
size
()),
[
self
.
batch_size
,
self
.
seq_length
,
self
.
type_sequence_label_size
])
self
.
parent
.
assertListEqual
(
list
(
list
(
mem
.
size
())
for
mem
in
result
[
"mems_1"
]),
[[
self
.
seq_length
,
self
.
batch_size
,
self
.
hidden_size
]]
*
self
.
num_hidden_layers
)
def
prepare_config_and_inputs_for_common
(
self
):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
def
create_and_check_xlnet_sequence_classif
(
self
,
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
):
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
,
token_labels
):
model
=
XLNetForSequenceClassification
(
config
)
model
.
eval
()
...
...
@@ -289,7 +324,7 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs
=
self
.
prepare_config_and_inputs
()
(
config
,
input_ids_1
,
input_ids_2
,
input_ids_q
,
perm_mask
,
input_mask
,
target_mapping
,
segment_ids
,
lm_labels
,
sequence_labels
,
is_impossible_labels
)
=
config_and_inputs
sequence_labels
,
is_impossible_labels
,
token_labels
)
=
config_and_inputs
inputs_dict
=
{
'input_ids'
:
input_ids_1
}
return
config
,
inputs_dict
...
...
@@ -316,6 +351,11 @@ class XLNetModelTest(CommonTestCases.CommonModelTester):
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_sequence_classif
(
*
config_and_inputs
)
def
test_xlnet_token_classif
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
self
.
model_tester
.
create_and_check_xlnet_token_classif
(
*
config_and_inputs
)
def
test_xlnet_qa
(
self
):
self
.
model_tester
.
set_seed
()
config_and_inputs
=
self
.
model_tester
.
prepare_config_and_inputs
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment