Commit 70974592 authored by RafaelWO, committed by GitHub

Transformer-XL: Remove unused parameters (#7087)

* Removed 'tgt_len' and 'ext_len' from Transformer-XL

 * Some changes are still to be done

* Removed 'tgt_len' and 'ext_len' from Transformer-XL (2)

 * Removed comments
 * Fixed quality

* Changed warning to info
parent c183d81e
@@ -88,7 +88,7 @@ def main():
         )
     )
-    model.reset_length(args.tgt_len, args.ext_len, args.mem_len)
+    model.reset_memory_length(args.mem_len)
     if args.clamp_len > 0:
         model.clamp_len = args.clamp_len
     if args.same_length:
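For callers of this example script, the migration is a one-line change. A minimal sketch, assuming `model` is a loaded Transformer-XL model and `args` comes from the script's argument parser:

```python
# Before this commit, callers had to pass tgt_len and ext_len even though
# the model only ever used mem_len to size its recurrence memory:
#     model.reset_length(args.tgt_len, args.ext_len, args.mem_len)

# After this commit, only the memory length is set:
model.reset_memory_length(args.mem_len)
```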
@@ -62,10 +62,6 @@ class TransfoXLConfig(PretrainedConfig):
             Apply LayerNorm to the input instead of the output
         n_layer (:obj:`int`, optional, defaults to 18):
             Number of hidden layers in the Transformer encoder.
-        tgt_len (:obj:`int`, optional, defaults to 128):
-            Number of tokens to predict
-        ext_len (:obj:`int`, optional, defaults to 0):
-            Length of the extended context
         mem_len (:obj:`int`, optional, defaults to 1600):
             Length of the retained previous heads
         clamp_len (:obj:`int`, optional, defaults to 1000):
@@ -125,8 +121,6 @@ class TransfoXLConfig(PretrainedConfig):
         div_val=4,
         pre_lnorm=False,
         n_layer=18,
-        tgt_len=128,
-        ext_len=0,
         mem_len=1600,
         clamp_len=1000,
         same_length=True,
@@ -168,8 +162,6 @@ class TransfoXLConfig(PretrainedConfig):
         self.pre_lnorm = pre_lnorm
         self.n_layer = n_layer
         self.n_head = n_head
-        self.tgt_len = tgt_len
-        self.ext_len = ext_len
         self.mem_len = mem_len
         self.same_length = same_length
         self.attn_type = attn_type
@@ -187,7 +179,9 @@ class TransfoXLConfig(PretrainedConfig):
     @property
     def max_position_embeddings(self):
-        return self.tgt_len + self.ext_len + self.mem_len
+        # Message copied from Transformer-XL documentation
+        logger.info(f"The model {self.model_type} is one of the few models that has no sequence length limit.")
+        return -1

     @property
     def n_token(self):  # Backward compatibility
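Downstream code that keys off `max_position_embeddings` must now handle the -1 sentinel. A hedged sketch of such a check (the `max_len` handling is an assumption for illustration):

```python
from transformers import TransfoXLConfig

config = TransfoXLConfig()

# Accessing the property logs the informational message and returns -1,
# signalling that Transformer-XL imposes no fixed sequence length limit.
if config.max_position_embeddings == -1:
    max_len = None  # do not truncate inputs
else:
    max_len = config.max_position_embeddings
```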
@@ -15,8 +15,7 @@
 # limitations under the License.
 """ TF 2.0 Transformer XL model.
 """
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -107,10 +106,7 @@ class TFRelPartialLearnableMultiHeadAttn(tf.keras.layers.Layer):
         d_model,
         d_head,
         dropout,
-        dropatt=0,
-        tgt_len=None,
-        ext_len=None,
-        mem_len=None,
+        dropatt=0.0,
         pre_lnorm=False,
         r_r_bias=None,
         r_w_bias=None,
@@ -261,9 +257,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
         d_head,
         d_inner,
         dropout,
-        tgt_len=None,
-        ext_len=None,
-        mem_len=None,
         dropatt=0.0,
         pre_lnorm=False,
         r_w_bias=None,
@@ -280,9 +273,6 @@ class TFRelPartialLearnableDecoderLayer(tf.keras.layers.Layer):
             d_model,
             d_head,
             dropout,
-            tgt_len=tgt_len,
-            ext_len=ext_len,
-            mem_len=mem_len,
             dropatt=dropatt,
             pre_lnorm=pre_lnorm,
             r_w_bias=r_w_bias,
@@ -414,12 +404,7 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         self.drop = tf.keras.layers.Dropout(config.dropout)

         self.n_layer = config.n_layer
-        self.tgt_len = config.tgt_len
         self.mem_len = config.mem_len
-        self.ext_len = config.ext_len
-        self.max_klen = config.tgt_len + config.ext_len + config.mem_len
         self.attn_type = config.attn_type

         self.layers = []
@@ -432,9 +417,6 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
                     config.d_head,
                     config.d_inner,
                     config.dropout,
-                    tgt_len=config.tgt_len,
-                    ext_len=config.ext_len,
-                    mem_len=config.mem_len,
                     dropatt=config.dropatt,
                     pre_lnorm=config.pre_lnorm,
                     r_w_bias=None if self.untie_r else self.r_w_bias,
@@ -478,10 +460,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
     def backward_compatible(self):
         self.sample_softmax = -1

-    def reset_length(self, tgt_len, ext_len, mem_len):
-        self.tgt_len = tgt_len
+    def reset_memory_length(self, mem_len):
         self.mem_len = mem_len
-        self.ext_len = ext_len

     def _prune_heads(self, heads):
         raise NotImplementedError
@@ -506,12 +486,8 @@ class TFTransfoXLMainLayer(tf.keras.layers.Layer):
         assert len(hids) == len(mems), "len(hids) != len(mems)"

         # There are `mlen + qlen` steps that can be cached into mems
-        # For the next step, the last `ext_len` of the `qlen` tokens
-        # will be used as the extended context. Hence, we only cache
-        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
-        # to `mlen + qlen - self.ext_len`.
         new_mems = []
-        end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+        end_idx = mlen + max(0, qlen)
         beg_idx = max(0, end_idx - self.mem_len)
         for i in range(len(hids)):
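With `ext_len` gone, the cached window is simply the last `mem_len` steps of the `mlen + qlen` available hidden states. A plain-Python sketch with example values (the numbers are illustrative, not from the diff):

```python
# Example: a warmed-up memory of 1600 steps plus a new 128-token segment.
mlen, qlen, mem_len = 1600, 128, 1600

end_idx = mlen + max(0, qlen)        # 1728: cache up to the newest token
beg_idx = max(0, end_idx - mem_len)  # 128: drop everything older than mem_len

# Once warmed up, each new memory slice spans exactly mem_len steps.
assert end_idx - beg_idx == mem_len
```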
@@ -867,7 +843,14 @@ class TFTransfoXLLMHeadModel(TFTransfoXLPreTrainedModel):
             return None

     def reset_length(self, tgt_len, ext_len, mem_len):
-        self.transformer.reset_length(tgt_len, ext_len, mem_len)
+        warnings.warn(
+            "The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
+            FutureWarning,
+        )
+        self.transformer.reset_memory_length(mem_len)
+
+    def reset_memory_length(self, mem_len):
+        self.transformer.reset_memory_length(mem_len)

     def init_mems(self, bsz):
         return self.transformer.init_mems(bsz)
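Existing TF code that still calls `reset_length` keeps working but now triggers a `FutureWarning`. A usage sketch, assuming the standard `transfo-xl-wt103` checkpoint:

```python
import warnings

from transformers import TFTransfoXLLMHeadModel

model = TFTransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

# Deprecated path: tgt_len and ext_len are accepted but ignored.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    model.reset_length(tgt_len=128, ext_len=0, mem_len=800)
assert any(issubclass(w.category, FutureWarning) for w in caught)

# Preferred replacement: only the memory length matters.
model.reset_memory_length(800)
```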
@@ -17,8 +17,7 @@
     Adapted from https://github.com/kimiyoung/transformer-xl.
     In particular https://github.com/kimiyoung/transformer-xl/blob/master/pytorch/mem_transformer.py
 """
+import warnings
 from dataclasses import dataclass
 from typing import List, Optional, Tuple
@@ -234,9 +233,6 @@ class RelPartialLearnableMultiHeadAttn(nn.Module):
         d_head,
         dropout,
         dropatt=0,
-        tgt_len=None,
-        ext_len=None,
-        mem_len=None,
         pre_lnorm=False,
         r_r_bias=None,
         r_w_bias=None,
@@ -737,12 +733,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         self.drop = nn.Dropout(config.dropout)

         self.n_layer = config.n_layer
-        self.tgt_len = config.tgt_len
         self.mem_len = config.mem_len
-        self.ext_len = config.ext_len
-        self.max_klen = config.tgt_len + config.ext_len + config.mem_len
         self.attn_type = config.attn_type

         if not config.untie_r:
@@ -759,9 +750,6 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
                     config.d_head,
                     config.d_inner,
                     config.dropout,
-                    tgt_len=config.tgt_len,
-                    ext_len=config.ext_len,
-                    mem_len=config.mem_len,
                     dropatt=config.dropatt,
                     pre_lnorm=config.pre_lnorm,
                     r_w_bias=None if config.untie_r else self.r_w_bias,
@@ -791,10 +779,8 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
     def backward_compatible(self):
         self.sample_softmax = -1

-    def reset_length(self, tgt_len, ext_len, mem_len):
-        self.tgt_len = tgt_len
+    def reset_memory_length(self, mem_len):
         self.mem_len = mem_len
-        self.ext_len = ext_len

     def _prune_heads(self, heads):
         logger.info("Head pruning is not implemented for Transformer-XL model")
@@ -821,13 +807,9 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
         assert len(hids) == len(mems), "len(hids) != len(mems)"

         # There are `mlen + qlen` steps that can be cached into mems
-        # For the next step, the last `ext_len` of the `qlen` tokens
-        # will be used as the extended context. Hence, we only cache
-        # the tokens from `mlen + qlen - self.ext_len - self.mem_len`
-        # to `mlen + qlen - self.ext_len`.
         with torch.no_grad():
             new_mems = []
-            end_idx = mlen + max(0, qlen - 0 - self.ext_len)
+            end_idx = mlen + max(0, qlen)
             beg_idx = max(0, end_idx - self.mem_len)
             for i in range(len(hids)):
@@ -1010,7 +992,14 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
                 self.crit.out_projs[i] = self.transformer.word_emb.emb_projs[i]

     def reset_length(self, tgt_len, ext_len, mem_len):
-        self.transformer.reset_length(tgt_len, ext_len, mem_len)
+        warnings.warn(
+            "The method `reset_length` is deprecated and will be removed in a future version, use `reset_memory_length` instead.",
+            FutureWarning,
+        )
+        self.transformer.reset_memory_length(mem_len)
+
+    def reset_memory_length(self, mem_len):
+        self.transformer.reset_memory_length(mem_len)

     def init_mems(self, bsz):
         return self.transformer.init_mems(bsz)
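On the PyTorch side the new method slots into a normal inference loop. An end-to-end sketch, assuming the standard `transfo-xl-wt103` checkpoint (the mem_len value of 800 is an illustrative choice):

```python
import torch

from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer

tokenizer = TransfoXLTokenizer.from_pretrained("transfo-xl-wt103")
model = TransfoXLLMHeadModel.from_pretrained("transfo-xl-wt103")

# Shrink the recurrence memory from the configured default (1600 steps).
model.reset_memory_length(800)

inputs = tokenizer("The Transformer-XL memory", return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# outputs.mems now caches at most 800 hidden states per layer, and can be
# fed back via the `mems` argument when processing the next segment.
```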