Unverified Commit 1d6e71e1 authored by Patrick von Platen, committed by GitHub

[EncoderDecoder] Add Cross Attention for GPT2 (#6415)



* add cross attention layers for gpt2

* make gpt2 cross attention work

* finish bert2gpt2

* add explicit comments

* remove attention mask since not yet supported

* revert attn mask in pipeline

* Update src/transformers/modeling_gpt2.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/modeling_encoder_decoder.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent eb613b56
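For orientation, the feature this commit enables is a BERT encoder paired with a GPT-2 decoder inside EncoderDecoderModel ("bert2gpt2"). A rough usage sketch, not part of the commit itself, assuming a transformers build that contains these changes; the checkpoint names mirror the new test added below:

    from transformers import BertModel, BertTokenizer, EncoderDecoderModel, GPT2Config, GPT2LMHeadModel, GPT2Tokenizer

    # Build the decoder with the new cross-attention layers. They are randomly
    # initialized here, so outputs are meaningless until the model is fine-tuned.
    decoder_config = GPT2Config.from_pretrained("gpt2")
    decoder_config.add_cross_attention = True
    decoder = GPT2LMHeadModel.from_pretrained("gpt2", config=decoder_config)

    encoder = BertModel.from_pretrained("bert-base-cased")
    model = EncoderDecoderModel(encoder=encoder, decoder=decoder)

    bert_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

    input_ids = bert_tokenizer("This is a test sentence.", return_tensors="pt")["input_ids"]
    # GPT-2 has no dedicated decoder-start token; its BOS is <|endoftext|>, which the
    # new fallback in GenerationMixin below would also pick up automatically.
    output_ids = model.generate(input_ids, decoder_start_token_id=model.config.decoder.bos_token_id)
    print(gpt2_tokenizer.decode(output_ids[0]))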
@@ -372,11 +372,16 @@ class GenerationMixin:
         if self.config.is_encoder_decoder:
             if decoder_start_token_id is None:
-                decoder_start_token_id = bos_token_id
+                # see if BOS token can be used for decoder_start_token_id
+                if bos_token_id is not None:
+                    decoder_start_token_id = bos_token_id
+                elif hasattr(self.config, "decoder") and hasattr(self.config.decoder, "bos_token_id"):
+                    decoder_start_token_id = self.config.decoder.bos_token_id
+                else:
+                    raise ValueError(
+                        "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
+                    )
-            assert (
-                decoder_start_token_id is not None
-            ), "decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation"
             assert hasattr(self, "get_encoder"), "{} should have a 'get_encoder' function defined".format(self)
             assert callable(self.get_encoder), "{} should be a method".format(self.get_encoder)
...
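In other words, the hard assertion is replaced by an explicit fallback chain for the token that seeds decoder generation: an explicitly passed decoder_start_token_id wins, then bos_token_id, then the decoder sub-config's bos_token_id, otherwise a ValueError is raised. A condensed, purely hypothetical restatement of that order outside the mixin:

    def resolve_decoder_start_token_id(decoder_start_token_id, bos_token_id, config):
        # Hypothetical helper mirroring the logic added above, for illustration only.
        if decoder_start_token_id is not None:
            return decoder_start_token_id
        if bos_token_id is not None:
            return bos_token_id
        if hasattr(config, "decoder") and hasattr(config.decoder, "bos_token_id"):
            return config.decoder.bos_token_id
        raise ValueError("decoder_start_token_id or bos_token_id has to be defined for encoder-decoder generation")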
@@ -287,6 +287,8 @@ class EncoderDecoderModel(PreTrainedModel):
             **kwargs_decoder,
         )
 
+        # TODO(PVP): currently it is not possible to use `past`
+        # with the encoder/decoder framework -> should be implemented
         return decoder_outputs + encoder_outputs
 
     def prepare_inputs_for_generation(self, input_ids, past, attention_mask, **kwargs):
@@ -299,15 +301,24 @@ class EncoderDecoderModel(PreTrainedModel):
         encoder_outputs = (past,)
         decoder_inputs = self.decoder.prepare_inputs_for_generation(input_ids)
+        decoder_attention_mask = decoder_inputs["attention_mask"] if "attention_mask" in decoder_inputs else None
-        return {
+        input_dict = {
             "attention_mask": attention_mask,
-            "decoder_attention_mask": decoder_inputs["attention_mask"],
+            "decoder_attention_mask": decoder_attention_mask,
             "decoder_input_ids": decoder_inputs["input_ids"],
             "encoder_outputs": encoder_outputs,
         }
+
+        # Ideally all models should have a `use_cache`
+        # leave following to ifs until all have it implemented
+        if "use_cache" in decoder_inputs:
+            input_dict["decoder_use_cache"] = decoder_inputs["use_cache"]
+        if "past_key_values" in decoder_inputs:
+            input_dict["decoder_past_key_values"] = decoder_inputs["past_key_values"]
+        return input_dict
 
     def _reorder_cache(self, past, beam_idx):
-        # as a default encoder-decoder models do not re-order the past.
-        # TODO(PVP): might have to be updated, e.g. if GPT2 is to be used as a decoder
-        return past
+        # apply decoder cache reordering here
+        return self.decoder._reorder_cache(past, beam_idx)
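Delegating `_reorder_cache` to the decoder matters for beam search: after each step the surviving beams are a permutation of the previous ones, so any cached key/value tensors have to be gathered along the batch-times-beams dimension, and only the decoder knows its own cache layout. A standalone sketch of that gathering step (shapes are arbitrary):

    import torch

    # (batch * num_beams, heads, seq_len, head_dim) cache tensor from a previous decoding step
    cached_keys = torch.randn(4, 12, 7, 64)
    # beams 0 and 1 both continue from former beam 0; beams 2 and 3 keep their own history
    beam_idx = torch.tensor([0, 0, 2, 3])
    reordered = cached_keys.index_select(0, beam_idx)
    print(reordered.shape)  # torch.Size([4, 12, 7, 64])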
@@ -118,7 +118,7 @@ def load_tf_weights_in_gpt2(model, config, gpt2_checkpoint_path):
 class Attention(nn.Module):
-    def __init__(self, nx, n_ctx, config, scale=False):
+    def __init__(self, nx, n_ctx, config, scale=False, is_cross_attention=False):
         super().__init__()
 
         n_state = nx  # in Attention: n_state=768 (nx=n_embd)
@@ -131,8 +131,12 @@ class Attention(nn.Module):
         self.n_head = config.n_head
         self.split_size = n_state
         self.scale = scale
+        self.is_cross_attention = is_cross_attention
-        self.c_attn = Conv1D(n_state * 3, nx)
+        if self.is_cross_attention:
+            self.c_attn = Conv1D(2 * n_state, nx)
+            self.q_attn = Conv1D(n_state, nx)
+        else:
+            self.c_attn = Conv1D(3 * n_state, nx)
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
@@ -160,8 +164,11 @@ class Attention(nn.Module):
         if self.scale:
             w = w / (float(v.size(-1)) ** 0.5)
         nd, ns = w.size(-2), w.size(-1)
-        mask = self.bias[:, :, ns - nd : ns, :ns]
-        w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
+
+        if not self.is_cross_attention:
+            # only the "normal" self-attention layer implements the causal mask
+            mask = self.bias[:, :, ns - nd : ns, :ns]
+            w = torch.where(mask.bool(), w, self.masked_bias.to(w.dtype))
 
         if attention_mask is not None:
             # Apply the attention mask
@@ -193,10 +200,26 @@ class Attention(nn.Module):
         return x.permute(0, 2, 1, 3)  # (batch, head, seq_length, head_features)
 
     def forward(
-        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
     ):
-        x = self.c_attn(x)
-        query, key, value = x.split(self.split_size, dim=2)
+        if encoder_hidden_states is not None:
+            assert hasattr(
+                self, "q_attn"
+            ), "If class is used as cross attention, the weights `q_attn` have to be defined. Please make sure to instantiate class with `Attention(..., is_cross_attention=True)`."
+            query = self.q_attn(hidden_states)
+            key, value = self.c_attn(encoder_hidden_states).split(self.split_size, dim=2)
+            attention_mask = encoder_attention_mask
+        else:
+            query, key, value = self.c_attn(hidden_states).split(self.split_size, dim=2)
 
         query = self.split_heads(query)
         key = self.split_heads(key, k=True)
         value = self.split_heads(value)
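The projection sizes make the split explicit: self-attention derives query, key and value from the decoder's own hidden states through one `3 * n_state` projection, whereas cross-attention takes the query from `q_attn` over the decoder states and key/value from a `2 * n_state` projection over the encoder states, so every decoder position can attend to every encoder position. A minimal single-head sketch of the same wiring (nn.Linear stands in for Conv1D; sizes are arbitrary):

    import torch
    import torch.nn as nn

    n_state = 64
    decoder_states = torch.randn(2, 5, n_state)   # (batch, target_len, hidden)
    encoder_states = torch.randn(2, 9, n_state)   # (batch, source_len, hidden)

    q_attn = nn.Linear(n_state, n_state)          # plays the role of Conv1D(n_state, nx)
    kv_attn = nn.Linear(n_state, 2 * n_state)     # plays the role of Conv1D(2 * n_state, nx)

    query = q_attn(decoder_states)                                # queries from the decoder
    key, value = kv_attn(encoder_states).split(n_state, dim=2)    # keys/values from the encoder

    scores = torch.matmul(query, key.transpose(-1, -2)) / (n_state ** 0.5)  # (2, 5, 9)
    context = torch.matmul(scores.softmax(dim=-1), value)                   # (2, 5, 64)
    print(context.shape)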
@@ -239,32 +262,64 @@ class MLP(nn.Module):
 class Block(nn.Module):
     def __init__(self, n_ctx, config, scale=False):
         super().__init__()
-        nx = config.n_embd
-        inner_dim = config.n_inner if config.n_inner is not None else 4 * nx
-        self.ln_1 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
-        self.attn = Attention(nx, n_ctx, config, scale)
-        self.ln_2 = nn.LayerNorm(nx, eps=config.layer_norm_epsilon)
+        hidden_size = config.n_embd
+        inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
+        self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        self.attn = Attention(hidden_size, n_ctx, config, scale)
+        self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
+        if config.add_cross_attention:
+            self.crossattention = Attention(hidden_size, n_ctx, config, scale, is_cross_attention=True)
+            self.ln_cross_attn = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = MLP(inner_dim, config)
 
     def forward(
-        self, x, layer_past=None, attention_mask=None, head_mask=None, use_cache=False, output_attentions=False,
+        self,
+        hidden_states,
+        layer_past=None,
+        attention_mask=None,
+        head_mask=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
+        use_cache=False,
+        output_attentions=False,
     ):
-        output_attn = self.attn(
-            self.ln_1(x),
+        attn_outputs = self.attn(
+            self.ln_1(hidden_states),
             layer_past=layer_past,
             attention_mask=attention_mask,
             head_mask=head_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
         )
-        a = output_attn[0]  # output_attn: a, present, (attentions)
+        attn_output = attn_outputs[0]  # output_attn: a, present, (attentions)
+        outputs = attn_outputs[1:]
+        # residual connection
+        hidden_states = attn_output + hidden_states
 
-        x = x + a
-        m = self.mlp(self.ln_2(x))
-        x = x + m
+        if encoder_hidden_states is not None:
+            # add one self-attention block for cross-attention
+            assert hasattr(
+                self, "crossattention"
+            ), f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers by setting `config.add_cross_attention=True`"
+            cross_attn_outputs = self.crossattention(
+                self.ln_cross_attn(hidden_states),
+                attention_mask=attention_mask,
+                head_mask=head_mask,
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
+                output_attentions=output_attentions,
+            )
+            attn_output = cross_attn_outputs[0]
+            # residual connection
+            hidden_states = hidden_states + attn_output
+            outputs = outputs + cross_attn_outputs[1:]  # add cross attentions if we output attention weights
 
-        outputs = [x] + output_attn[1:]
-        return outputs  # x, present, (attentions)
+        feed_forward_hidden_states = self.mlp(self.ln_2(hidden_states))
+        # residual connection
+        hidden_states = hidden_states + feed_forward_hidden_states
 
+        outputs = [hidden_states] + outputs
+        return outputs  # hidden_states, present, (cross_attentions, attentions)
 
 class GPT2PreTrainedModel(PreTrainedModel):
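Taken together, a block now runs up to three pre-LayerNorm residual sub-layers instead of two: self-attention, then (if `config.add_cross_attention` is set and encoder states are passed) cross-attention over the encoder outputs, then the MLP. A simplified schematic of that order, ignoring caching, masks and returned attention weights:

    # Simplified sketch of Block.forward with cross-attention enabled (illustration only).
    def block_forward(block, hidden_states, encoder_hidden_states=None):
        hidden_states = hidden_states + block.attn(block.ln_1(hidden_states))[0]
        if encoder_hidden_states is not None:
            hidden_states = hidden_states + block.crossattention(
                block.ln_cross_attn(hidden_states), encoder_hidden_states=encoder_hidden_states
            )[0]
        hidden_states = hidden_states + block.mlp(block.ln_2(hidden_states))
        return hidden_states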
@@ -449,6 +504,8 @@ class GPT2Model(GPT2PreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
         use_cache=None,
         output_attentions=None,
         output_hidden_states=None,
@@ -506,7 +563,7 @@ class GPT2Model(GPT2PreTrainedModel):
             # So we can broadcast to [batch_size, num_heads, from_seq_length, to_seq_length]
             # this attention mask is more simple than the triangular masking of causal attention
             # used in OpenAI GPT, we just need to prepare the broadcast dimension here.
-            attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
+            attention_mask = attention_mask[:, None, None, :]
 
             # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
             # masked positions, this operation will create a tensor which is 0.0 for
@@ -516,6 +573,17 @@ class GPT2Model(GPT2PreTrainedModel):
             attention_mask = attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
             attention_mask = (1.0 - attention_mask) * -10000.0
 
+        # If a 2D or 3D attention mask is provided for the cross-attention,
+        # we need to make it broadcastable to [batch_size, num_heads, seq_length, seq_length]
+        if self.config.add_cross_attention and encoder_hidden_states is not None:
+            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
+            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
+            if encoder_attention_mask is None:
+                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
+            encoder_attention_mask = self.invert_attention_mask(encoder_attention_mask)
+        else:
+            encoder_attention_mask = None
+
         # Prepare head mask if needed
         # 1.0 in head_mask indicate we keep the head
         # attention_probs has shape bsz x n_heads x N x N
@@ -546,6 +614,8 @@ class GPT2Model(GPT2PreTrainedModel):
                 layer_past=layer_past,
                 attention_mask=attention_mask,
                 head_mask=head_mask[i],
+                encoder_hidden_states=encoder_hidden_states,
+                encoder_attention_mask=encoder_attention_mask,
                 use_cache=use_cache,
                 output_attentions=output_attentions,
             )
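Both masks end up as additive biases on the attention scores: the decoder's own padding mask is broadcast to `[batch_size, 1, 1, seq_length]` and mapped to 0 / -10000, and the encoder mask is run through `invert_attention_mask`, which performs a similar transformation for the cross-attention scores. A tiny sketch of the transformation used above:

    import torch

    attention_mask = torch.tensor([[1, 1, 1, 0, 0]])       # 1 = attend, 0 = padding
    extended = attention_mask[:, None, None, :].float()    # broadcastable over heads and query positions
    extended = (1.0 - extended) * -10000.0
    print(extended)  # 0. for real tokens, -10000. for padded positions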
@@ -593,17 +663,21 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
     def get_output_embeddings(self):
         return self.lm_head
 
-    def prepare_inputs_for_generation(self, input_ids, past, **kwargs):
+    def prepare_inputs_for_generation(self, input_ids, past=None, **kwargs):
         # only last token for inputs_ids if past is defined in kwargs
         if past:
             input_ids = input_ids[:, -1].unsqueeze(-1)
 
-        return {"input_ids": input_ids, "past_key_values": past, "use_cache": kwargs["use_cache"]}
+        return {
+            "input_ids": input_ids,
+            "past_key_values": past,
+            "use_cache": kwargs.get("use_cache"),
+        }
 
     @add_start_docstrings_to_callable(GPT2_INPUTS_DOCSTRING)
     @add_code_sample_docstrings(
         tokenizer_class=_TOKENIZER_FOR_DOC,
-        checkpoint="ctrl",
+        checkpoint="gpt2",
         output_type=CausalLMOutputWithPast,
         config_class=_CONFIG_FOR_DOC,
     )
@@ -616,6 +690,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
         position_ids=None,
         head_mask=None,
         inputs_embeds=None,
+        encoder_hidden_states=None,
+        encoder_attention_mask=None,
         labels=None,
         use_cache=None,
         output_attentions=None,
@@ -648,6 +724,8 @@ class GPT2LMHeadModel(GPT2PreTrainedModel):
             position_ids=position_ids,
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
             use_cache=use_cache,
             output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
...
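With these changes GPT2Model and GPT2LMHeadModel can be fed encoder states directly, which is what EncoderDecoderModel does internally. A minimal sketch with a small randomly initialized config (shapes are arbitrary; note the encoder states must match GPT-2's hidden size n_embd, since the same c_attn projection is applied to them):

    import torch
    from transformers import GPT2Config, GPT2LMHeadModel

    config = GPT2Config(n_embd=64, n_layer=2, n_head=2, add_cross_attention=True)
    model = GPT2LMHeadModel(config)

    input_ids = torch.randint(0, config.vocab_size, (1, 5))
    encoder_hidden_states = torch.randn(1, 9, config.n_embd)       # stand-in for encoder outputs of width n_embd
    encoder_attention_mask = torch.ones(1, 9, dtype=torch.long)

    outputs = model(
        input_ids,
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=encoder_attention_mask,
    )
    logits = outputs[0]  # (1, 5, vocab_size)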
@@ -20,10 +20,9 @@ import unittest
 from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
-# TODO(PVP): this line reruns all the tests in BertModelTest; not sure whether this can be prevented
-# for now only run module with pytest tests/test_modeling_encoder_decoder.py::EncoderDecoderModelTest
 from .test_modeling_bert import BertModelTester
 from .test_modeling_common import ids_tensor
+from .test_modeling_gpt2 import GPT2ModelTester
 from .test_modeling_roberta import RobertaModelTester
@@ -31,6 +30,7 @@ if is_torch_available():
     from transformers import (
         BertModel,
         BertLMHeadModel,
+        GPT2LMHeadModel,
         RobertaModel,
         RobertaForCausalLM,
         EncoderDecoderModel,
@@ -424,3 +424,59 @@ class RoBertaEncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
     def get_pretrained_model(self):
         return EncoderDecoderModel.from_encoder_decoder_pretrained("roberta-base", "roberta-base")
+class GPT2EncoderDecoderModelTest(EncoderDecoderMixin, unittest.TestCase):
+    def get_encoder_decoder_model(self, config, decoder_config):
+        encoder_model = BertModel(config)
+        decoder_model = GPT2LMHeadModel(decoder_config)
+        return encoder_model, decoder_model
+
+    def prepare_config_and_inputs(self):
+        model_tester_encoder = BertModelTester(self, batch_size=13)
+        model_tester_decoder = GPT2ModelTester(self, batch_size=13)
+        encoder_config_and_inputs = model_tester_encoder.prepare_config_and_inputs()
+        decoder_config_and_inputs = model_tester_decoder.prepare_config_and_inputs_for_decoder()
+        (
+            config,
+            input_ids,
+            token_type_ids,
+            input_mask,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = encoder_config_and_inputs
+        (
+            decoder_config,
+            decoder_input_ids,
+            decoder_input_mask,
+            decoder_head_mask,
+            decoder_token_type_ids,
+            decoder_sequence_labels,
+            decoder_token_labels,
+            decoder_choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        ) = decoder_config_and_inputs
+
+        # make sure that cross attention layers are added
+        decoder_config.add_cross_attention = True
+        # disable cache for now
+        decoder_config.use_cache = False
+        return {
+            "config": config,
+            "input_ids": input_ids,
+            "attention_mask": input_mask,
+            "decoder_config": decoder_config,
+            "decoder_input_ids": decoder_input_ids,
+            "decoder_token_type_ids": decoder_token_type_ids,
+            "decoder_attention_mask": decoder_input_mask,
+            "decoder_sequence_labels": decoder_sequence_labels,
+            "decoder_token_labels": decoder_token_labels,
+            "decoder_choice_labels": decoder_choice_labels,
+            "encoder_hidden_states": encoder_hidden_states,
+            "labels": decoder_token_labels,
+        }
+
+    def get_pretrained_model(self):
+        return EncoderDecoderModel.from_encoder_decoder_pretrained("bert-base-cased", "gpt2")
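If useful, the new test class should be runnable on its own with pytest tests/test_modeling_encoder_decoder.py::GPT2EncoderDecoderModelTest, mirroring the module-level command mentioned in the comment removed above; note that the config it builds disables the cache (decoder_config.use_cache = False), matching the TODO about `past` support in the encoder/decoder framework.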
@@ -20,7 +20,7 @@ from transformers import is_torch_available
 from transformers.testing_utils import require_torch, slow, torch_device
 
 from .test_configuration_common import ConfigTester
-from .test_modeling_common import ModelTesterMixin, ids_tensor
+from .test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
 
 if is_torch_available():
@@ -62,27 +62,27 @@ class GPT2ModelTester:
         scope=None,
     ):
         self.parent = parent
-        self.batch_size = 14
-        self.seq_length = 7
-        self.is_training = True
-        self.use_token_type_ids = True
-        self.use_input_mask = True
-        self.use_labels = True
-        self.use_mc_token_ids = True
-        self.vocab_size = 99
-        self.hidden_size = 32
-        self.num_hidden_layers = 5
-        self.num_attention_heads = 4
-        self.intermediate_size = 37
-        self.hidden_act = "gelu"
-        self.hidden_dropout_prob = 0.1
-        self.attention_probs_dropout_prob = 0, 1
-        self.max_position_embeddings = 512
-        self.type_vocab_size = 16
-        self.type_sequence_label_size = 2
-        self.initializer_range = 0.02
-        self.num_labels = 3
-        self.num_choices = 4
+        self.batch_size = batch_size
+        self.seq_length = seq_length
+        self.is_training = is_training
+        self.use_token_type_ids = use_token_type_ids
+        self.use_input_mask = use_input_mask
+        self.use_labels = use_labels
+        self.use_mc_token_ids = use_mc_token_ids
+        self.vocab_size = vocab_size
+        self.hidden_size = hidden_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+        self.intermediate_size = intermediate_size
+        self.hidden_act = hidden_act
+        self.hidden_dropout_prob = hidden_dropout_prob
+        self.attention_probs_dropout_prob = attention_probs_dropout_prob
+        self.max_position_embeddings = max_position_embeddings
+        self.type_vocab_size = type_vocab_size
+        self.type_sequence_label_size = type_sequence_label_size
+        self.initializer_range = initializer_range
+        self.num_labels = num_labels
+        self.num_choices = num_choices
         self.scope = None
         self.bos_token_id = vocab_size - 1
         self.eos_token_id = vocab_size - 1
@@ -142,6 +142,35 @@ class GPT2ModelTester:
             choice_labels,
         )
 
+    def prepare_config_and_inputs_for_decoder(self):
+        (
+            config,
+            input_ids,
+            input_mask,
+            head_mask,
+            token_type_ids,
+            mc_token_ids,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+        ) = self.prepare_config_and_inputs()
+
+        encoder_hidden_states = floats_tensor([self.batch_size, self.seq_length, self.hidden_size])
+        encoder_attention_mask = ids_tensor([self.batch_size, self.seq_length], vocab_size=2)
+
+        return (
+            config,
+            input_ids,
+            input_mask,
+            head_mask,
+            token_type_ids,
+            sequence_labels,
+            token_labels,
+            choice_labels,
+            encoder_hidden_states,
+            encoder_attention_mask,
+        )
+
     def create_and_check_gpt2_model(self, config, input_ids, input_mask, head_mask, token_type_ids, *args):
         model = GPT2Model(config=config)
         model.to(torch_device)
...