Unverified commit 70974592, authored by RafaelWO and committed by GitHub

Transformer-XL: Remove unused parameters (#7087)

* Removed 'tgt_len' and 'ext_len' from Transformer-XL

  * Some changes are still to be done

* Removed 'tgt_len' and 'ext_len' from Transformer-XL (2)

  * Removed comments
  * Fixed quality

* Changed warning to info
parent c183d81e
```diff
@@ -88,7 +88,7 @@ def main():
         )
     )
-    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
+    model.reset_memory_length(args.mem_len)
     if args.clamp_len > 0:
         model.clamp_len = args.clamp_len
     if args.same_length:
```
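For downstream code the migration is a one-line change. A minimal sketch, assuming the pretrained `transfo-xl-wt103` checkpoint and a hypothetical `args` namespace standing in for the evaluation script's parsed CLI arguments:

```python
from argparse import Namespace

from transformers import TransfoXLLMHeadModel

# Hypothetical stand-in for the evaluation script's argparse namespace.
args = Namespace(tgt_len=128, ext_len=0, mem_len=1600)

model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

# Before this commit:
#   model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
# After this commit only the memory length remains configurable:
model.reset_memory_length(args.mem_len)
```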
```diff
@@ -62,10 +62,6 @@ class TransfoXLConfig(PretrainedConfig):
             Apply LayerNorm to the input instead of the output
         n_layer (:obj:`int`, optional, defaults to 18):
             Number of hidden layers in the Transformer encoder.
-        tgt_len (:obj:`int`, optional, defaults to 128):
-            Number of tokens to predict
-        ext_len (:obj:`int`, optional, defaults to 0):
-            Length of the extended context
         mem_len (:obj:`int`, optional, defaults to 1600):
             Length of the retained previous heads
         clamp_len (:obj:`int`, optional, defaults to 1000):
@@ -125,8 +121,6 @@ class TransfoXLConfig(PretrainedConfig):
         div_val=4,
         pre_lnorm=False,
         n_layer=18,
-        tgt_len=128,
-        ext_len=0,
         mem_len=1600,
         clamp_len=1000,
         same_length=True,
@@ -168,8 +162,6 @@ class TransfoXLConfig(PretrainedConfig):
         self.pre_lnorm = pre_lnorm
         self.n_layer = n_layer
         self.n_head = n_head
-        self.tgt_len = tgt_len
-        self.ext_len = ext_len
         self.mem_len = mem_len
         self.same_length = same_length
         self.attn_type = attn_type
```
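A quick sanity check of the slimmed-down configuration; a minimal sketch using only parameters visible in the signature above (the values are the documented defaults):

```python
from transformers import TransfoXLConfig

config = TransfoXLConfig(mem_len=1600, clamp_len=1000, same_length=True)

print(config.mem_len)              # 1600
# tgt_len/ext_len are no longer set in __init__, so a freshly built config
# should not carry them anymore.
print(hasattr(config, "tgt_len"))  # expected: False
print(hasattr(config, "ext_len"))  # expected: False
```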
```diff
@@ -187,7 +179,9 @@ class TransfoXLConfig(PretrainedConfig):
     @property
     def max_position_embeddings(self):
-        return self.tgt_len + self.ext_len + self.mem_len
+        # Message copied from Transformer-XL documentation
+        logger.info(f"The model {self.model_type} is one of the few models that has no sequence length limit.")
+        return -1

     @property
     def n_token(self):  # Backward compatibility
```
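Since the combined window `tgt_len + ext_len + mem_len` no longer exists, the property now advertises that Transformer-XL has no fixed sequence length. A small sketch of the new behaviour:

```python
from transformers import TransfoXLConfig

config = TransfoXLConfig()

# Previously this returned config.tgt_len + config.ext_len + config.mem_len
# (128 + 0 + 1600 with the old defaults). Now it logs an info message and
# returns -1, i.e. "no fixed position limit".
print(config.max_position_embeddings)  # -1
```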
```diff
@@ -15,8 +15,7 @@
 # limitations under the License.
 """ TF 2.0 Transformer XL model.
 """
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -107,10 +106,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         d_model,
         d_head,
         dropout,
-        dropatt=0,
-        tgt_len=None,
-        ext_len=None,
-        mem_len=None,
+        dropatt=0.0,
         pre_lnorm=False,
         r_r_bias=None,
         r_w_bias=None,
@@ -261,9 +257,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
         d_head,
         d_inner,
         dropout,
-        tgt_len=None,
-        ext_len=None,
-        mem_len=None,
         dropatt=0.0,
         pre_lnorm=False,
         r_w_bias=None,
@@ -280,9 +273,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
             d_model,
             d_head,
             dropout,
-            tgt_len=tgt_len,
-            ext_len=ext_len,
-            mem_len=mem_len,
             dropatt=dropatt,
             pre_lnorm=pre_lnorm,
             r_w_bias=r_w_bias,
@@ -414,12 +404,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         self.drop = tf.keras.layers.Dropout(config.dropout)

         self.n_layer = config.n_layer
-        self.tgt_len = config.tgt_len
         self.mem_len = config.mem_len
-        self.ext_len = config.ext_len
-        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

         self.attn_type = config.attn_type
         self.layers = []
@@ -432,9 +417,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                     config.d_head,
                     config.d_inner,
                     config.dropout,
-                    tgt_len=config.tgt_len,
-                    ext_len=config.ext_len,
-                    mem_len=config.mem_len,
                     dropatt=config.dropatt,
                     pre_lnorm=config.pre_lnorm,
                     r_w_bias=None if self.untie_r else self.r_w_bias,
@@ -478,10 +460,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def backward_compatible(self):
         self.sample_softmax = -1

-    def reset_length(self, tgt_len, ext_len, mem_len):
-        self.tgt_len = tgt_len
+    def reset_memory_length(self, mem_len):
         self.mem_len = mem_len
-        self.ext_len = ext_len

     def _prune_heads(self, heads):
         raise NotImplementedError
@@ -506,12 +486,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         assert len(hids) == len(mems), "len(hids) != len(mems)"

         # There are `mlen + qlen` steps that can be cached into mems
-        # For the next step, the last `ext_len` of the `qlen` tokens
-        # will be used as the extended context. Hence, we only cache
-        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
-        # to `mlen + qlen - self.ext_len`.
         new_mems = []
-        end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+        end_idx = mlen + max(0, qlen)
         beg_idx = max(0, end_idx - self.mem_len)

         for i in range(len(hids)):
```
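With `ext_len` gone, the cache window is simply the last `mem_len` of the `mlen + qlen` available steps. A standalone sketch of the index arithmetic, with illustrative numbers (not taken from the diff):

```python
# Illustrative values: 1600 cached steps, 128 new tokens, memory capped at 1600.
mlen, qlen, mem_len = 1600, 128, 1600

# Old formula: end_idx = mlen + max(0, qlen - 0 - ext_len); identical when ext_len == 0.
end_idx = mlen + max(0, qlen)        # 1728 steps are available for caching
beg_idx = max(0, end_idx - mem_len)  # 128: keep only the most recent `mem_len` steps

# In `_update_mems`, the previous mems and the new hidden states are concatenated
# along the time axis and sliced with [beg_idx:end_idx] to form the new mems.
print(beg_idx, end_idx, end_idx - beg_idx)  # 128 1728 1600
```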
```diff
@@ -867,7 +843,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
             return None

     def reset_length(self, tgt_len, ext_len, mem_len):
-        self.transformer.reset_length(tgt_len, ext_len, mem_len)
+        warnings.warn(
+            "The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
+            FutureWarning,
+        )
+        self.transformer.reset_memory_length(mem_len)
+
+    def reset_memory_length(self, mem_len):
+        self.transformer.reset_memory_length(mem_len)

     def init_mems(self, bsz):
         return self.transformer.init_mems(bsz)
```
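The deprecated entry point keeps working but now warns. A minimal sketch of what callers observe, using a small randomly initialised TF model so no pretrained weights are needed (the tiny config values are purely illustrative, not from the diff):

```python
import warnings

from transformers import TFTransfoXLLMHeadModel, TransfoXLConfig

# Tiny illustrative config; Keras creates the weights lazily on the first call.
config = TransfoXLConfig(n_layer=2, d_model=32, d_embed=32, n_head=2, d_head=16, d_inner=64, mem_len=64)
model = TFTransfoXLLMHeadModel(config)

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.reset_length(tgt_len=128, ext_len=0, mem_len=32)  # deprecated shim; tgt_len/ext_len are ignored

assert any(issubclass(w.category, FutureWarning) for w in caught)
assert model.transformer.mem_len == 32

# Preferred call going forward:
model.reset_memory_length(64)
```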
```diff
@@ -17,8 +17,7 @@
  Adapted from https://github.com/kimiyoung/transformer-xl.
  In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
 """
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -234,9 +233,6 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         d_head,
         dropout,
         dropatt=0,
-        tgt_len=None,
-        ext_len=None,
-        mem_len=None,
         pre_lnorm=False,
         r_r_bias=None,
         r_w_bias=None,
@@ -737,12 +733,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.drop = nn.Dropout(config.dropout)

         self.n_layer = config.n_layer
-        self.tgt_len = config.tgt_len
         self.mem_len = config.mem_len
-        self.ext_len = config.ext_len
-        self.max_klen = config.tgt_len + config.ext_len + config.mem_len

         self.attn_type = config.attn_type

         if not config.untie_r:
@@ -759,9 +750,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                     config.d_head,
                     config.d_inner,
                     config.dropout,
-                    tgt_len=config.tgt_len,
-                    ext_len=config.ext_len,
-                    mem_len=config.mem_len,
                     dropatt=config.dropatt,
                     pre_lnorm=config.pre_lnorm,
                     r_w_bias=None if config.untie_r else self.r_w_bias,
@@ -791,10 +779,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
     def backward_compatible(self):
         self.sample_softmax = -1

-    def reset_length(self, tgt_len, ext_len, mem_len):
-        self.tgt_len = tgt_len
+    def reset_memory_length(self, mem_len):
         self.mem_len = mem_len
-        self.ext_len = ext_len

     def _prune_heads(self, heads):
         logger.info("Head pruning is not implemented for Transformer-XL model")
@@ -821,13 +807,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         assert len(hids) == len(mems), "len(hids) != len(mems)"

         # There are `mlen + qlen` steps that can be cached into mems
-        # For the next step, the last `ext_len` of the `qlen` tokens
-        # will be used as the extended context. Hence, we only cache
-        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
-        # to `mlen + qlen - self.ext_len`.
         with torch.no_grad():
             new_mems = []
-            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+            end_idx = mlen + max(0, qlen)
             beg_idx = max(0, end_idx - self.mem_len)

             for i in range(len(hids)):
@@ -1010,7 +992,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
                 self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]

     def reset_length(self, tgt_len, ext_len, mem_len):
-        self.transformer.reset_length(tgt_len, ext_len, mem_len)
+        warnings.warn(
+            "The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
+            FutureWarning,
+        )
+        self.transformer.reset_memory_length(mem_len)
+
+    def reset_memory_length(self, mem_len):
+        self.transformer.reset_memory_length(mem_len)

     def init_mems(self, bsz):
         return self.transformer.init_mems(bsz)
```