"tests/optimization_test.py" did not exist on "158e82e061c02fc2f1613adb7ac1d1cb6adae71c"
Commit 3cf2020c authored by Rémi Louf

change kwargs processing

parent a88a0e44
@@ -114,23 +114,28 @@ class PreTrainedEncoderDecoder(nn.Module):
         # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
         # that apply to the model as a whole.
         # We let the specific kwargs override the common ones in case of conflict.
-        kwargs_encoder = {
-            argument[len("encoder_"):]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("encoder_")
-        }
-        kwargs_decoder = {
-            argument[len("decoder_"):]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("decoder_")
-        }
-        kwargs_common = {
+        kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
+            if not argument.startswith("encoder_")
+            and not argument.startswith("decoder_")
         }
-        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
-        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_encoder.update(
+            {
+                argument[len("encoder_") :]: value
+                for argument, value in kwargs.items()
+                if argument.startswith("encoder_")
+            }
+        )
+        kwargs_decoder.update(
+            {
+                argument[len("decoder_") :]: value
+                for argument, value in kwargs.items()
+                if argument.startswith("decoder_")
+            }
+        )
 
         # Load and initialize the encoder and decoder
         # The distinction between encoder and decoder at the model level is made
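For reference, the new splitting behaves like the standalone sketch below; the helper name split_kwargs is hypothetical, added here only for illustration. Compared with the old dict(kwargs_common, **kwargs_specific) construction, starting from a copy of the common kwargs and then calling update() with the prefixed ones makes the override order explicit: submodel-specific arguments win on conflict.

def split_kwargs(kwargs):
    # Arguments with neither prefix are common to both submodels.
    kwargs_common = {
        argument: value
        for argument, value in kwargs.items()
        if not argument.startswith("encoder_")
        and not argument.startswith("decoder_")
    }
    # Start from the common kwargs, then let the prefixed,
    # submodel-specific ones override them on conflict.
    kwargs_encoder = kwargs_common.copy()
    kwargs_decoder = kwargs_common.copy()
    kwargs_encoder.update(
        {
            argument[len("encoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("encoder_")
        }
    )
    kwargs_decoder.update(
        {
            argument[len("decoder_") :]: value
            for argument, value in kwargs.items()
            if argument.startswith("decoder_")
        }
    )
    return kwargs_encoder, kwargs_decoder


# Example: the common `attention_mask` reaches both submodels, but the
# decoder-specific value overrides it for the decoder.
enc, dec = split_kwargs(
    {"attention_mask": "common", "decoder_attention_mask": "specific"}
)
assert enc == {"attention_mask": "common"}
assert dec == {"attention_mask": "specific"}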
@@ -185,35 +190,44 @@ class PreTrainedEncoderDecoder(nn.Module):
         # `encoder_`), decoder-specific (prefixed by `decoder_`) and those
         # that apply to the model as a whole.
         # We let the specific kwargs override the common ones in case of conflict.
-        kwargs_encoder = {
-            argument[len("encoder_"):]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("encoder_")
-        }
-        kwargs_decoder = {
-            argument[len("decoder_"):]: value
-            for argument, value in kwargs.items()
-            if argument.startswith("decoder_")
-        }
-        kwargs_common = {
+        kwargs_common = {
             argument: value
             for argument, value in kwargs.items()
-            if not (argument.startswith("encoder_") or argument.startswith("decoder_"))
+            if not argument.startswith("encoder_")
+            and not argument.startswith("decoder_")
         }
-        kwargs_decoder = dict(kwargs_common, **kwargs_decoder)
-        kwargs_encoder = dict(kwargs_common, **kwargs_encoder)
+        kwargs_decoder = kwargs_common.copy()
+        kwargs_encoder = kwargs_common.copy()
+        kwargs_encoder.update(
+            {
+                argument[len("encoder_") :]: value
+                for argument, value in kwargs.items()
+                if argument.startswith("encoder_")
+            }
+        )
+        kwargs_decoder.update(
+            {
+                argument[len("decoder_") :]: value
+                for argument, value in kwargs.items()
+                if argument.startswith("decoder_")
+            }
+        )
 
         # Encode if needed (training, first prediction pass)
         encoder_hidden_states = kwargs_encoder.pop("hidden_states", None)
         if encoder_hidden_states is None:
             encoder_outputs = self.encoder(encoder_input_ids, **kwargs_encoder)
-            encoder_hidden_states = encoder_outputs[0]  # output the last layer hidden state
+            encoder_hidden_states = encoder_outputs[
+                0
+            ]  # output the last layer hidden state
         else:
             encoder_outputs = ()
 
         # Decode
         kwargs_decoder["encoder_hidden_states"] = encoder_hidden_states
-        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get("attention_mask", None)
+        kwargs_decoder["encoder_attention_mask"] = kwargs_encoder.get(
+            "attention_mask", None
+        )
         decoder_outputs = self.decoder(decoder_input_ids, **kwargs_decoder)
 
         return decoder_outputs + encoder_outputs
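A hedged usage sketch of the encode-once path above, assuming an already-instantiated model (a PreTrainedEncoderDecoder) and pre-built input id tensors; the variable names are illustrative, not from the commit:

import torch

with torch.no_grad():
    # Pre-compute the encoder's last-layer hidden states once.
    encoder_hidden_states = model.encoder(encoder_input_ids)[0]

    # On later calls, `encoder_hidden_states` is routed by the kwargs
    # splitting to `kwargs_encoder["hidden_states"]`, which forward()
    # pops; since it is not None, the encoder pass is skipped.
    outputs = model(
        encoder_input_ids,
        decoder_input_ids,
        encoder_hidden_states=encoder_hidden_states,
    )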
@@ -235,6 +249,7 @@ class Model2Model(PreTrainedEncoderDecoder):
         decoder = BertForMaskedLM(config)
         model = Model2Model.from_pretrained('bert-base-uncased', decoder_model=decoder)
     """
+
     def __init__(self, *args, **kwargs):
         super(Model2Model, self).__init__(*args, **kwargs)
         self.tie_weights()