Unverified Commit bf37e5c7 authored by ADAning, committed by GitHub

Fix T5 incorrect weight decay in Trainer and official summarization example (#18002)

* Add ALL_LAYERNORM_LAYERS for LayerNorm

* Fix bug when appending the layer norm class
parent 22edb68d
@@ -526,7 +526,7 @@ def main():
     # Optimizer
     # Split weights in two groups, one with weight decay and the other not.
-    no_decay = ["bias", "LayerNorm.weight"]
+    no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
     optimizer_grouped_parameters = [
         {
             "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
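The entries in no_decay are matched as plain substrings against parameter names. T5-style checkpoints name their norm weights "layer_norm.weight" (for example encoder.block.0.layer.0.layer_norm.weight), not "LayerNorm.weight", so the old list never matched them and those weights incorrectly received weight decay. A minimal sketch of that matching, outside the script, using illustrative T5-style parameter names:

# Sketch only (not part of the diff): how the substring test splits parameters.
# The parameter names below are illustrative T5-style names.
no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight"]
named_parameters = [
    "encoder.block.0.layer.0.SelfAttention.q.weight",
    "encoder.block.0.layer.0.layer_norm.weight",  # T5-style layer norm weight
    "decoder.final_layer_norm.weight",
    "lm_head.weight",
]

decay, skip = [], []
for name in named_parameters:
    # any(nd in name ...) is a substring check, so "layer_norm.weight" must be
    # listed explicitly or the T5 norm weights fall into the decay group
    (skip if any(nd in name for nd in no_decay) else decay).append(name)

print("weight decay:", decay)  # attention and lm_head weights
print("no decay:", skip)       # layer norm weights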
@@ -32,7 +32,8 @@ from ...modeling_outputs import (
     Seq2SeqLMOutput,
     Seq2SeqModelOutput,
 )
-from ...modeling_utils import PreTrainedModel, find_pruneable_heads_and_indices, prune_linear_layer
+from ...modeling_utils import PreTrainedModel
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
@@ -260,6 +261,8 @@ except Exception:
     logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
     pass

+ALL_LAYERNORM_LAYERS.append(LongT5LayerNorm)
+
 # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
 class LongT5DenseActDense(nn.Module):
@@ -34,7 +34,7 @@ from ...modeling_outputs import (
     Seq2SeqModelOutput,
 )
 from ...modeling_utils import PreTrainedModel
-from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
+from ...pytorch_utils import ALL_LAYERNORM_LAYERS, find_pruneable_heads_and_indices, prune_linear_layer
 from ...utils import (
     DUMMY_INPUTS,
     DUMMY_MASK,
@@ -275,6 +275,8 @@ except Exception:
     logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
     pass

+ALL_LAYERNORM_LAYERS.append(T5LayerNorm)
+
 class T5DenseActDense(nn.Module):
     def __init__(self, config: T5Config):
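T5LayerNorm (and the LongT5 copy above) is an RMS-style norm implemented as a plain nn.Module subclass, not a subclass of nn.LayerNorm, so a check against [nn.LayerNorm] alone never recognizes it and its weight was treated as decayable. Appending the class to ALL_LAYERNORM_LAYERS is what lets the Trainer's grouping see it. A small sketch of the type mismatch, with a simplified stand-in for T5LayerNorm (not the actual implementation):

# Sketch only: SimpleRMSNorm stands in for T5LayerNorm / LongT5LayerNorm.
import torch
from torch import nn

class SimpleRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # scale by the reciprocal RMS; no mean subtraction and no bias
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return self.weight * hidden_states * torch.rsqrt(variance + self.variance_epsilon)

norm = SimpleRMSNorm(8)
print(isinstance(norm, nn.LayerNorm))                 # False -> missed by [nn.LayerNorm]

ALL_LAYERNORM_LAYERS = [nn.LayerNorm]                 # as defined in pytorch_utils
ALL_LAYERNORM_LAYERS.append(SimpleRMSNorm)            # the per-model registration this commit adds
print(isinstance(norm, tuple(ALL_LAYERNORM_LAYERS)))  # True -> excluded from weight decay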
@@ -21,6 +21,8 @@ from torch import _softmax_backward_data, nn
 from .utils import logging

+ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
+
 logger = logging.get_logger(__name__)

 is_torch_less_than_1_8 = version.parse(version.parse(torch.__version__).base_version) < version.parse("1.8.0")
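The list is consumed through get_parameter_names, which the Trainer uses to collect the names of parameters that do not live inside any module of the listed types; everything it returns (minus biases) becomes the weight-decay group. A rough sketch of that idea, not the actual transformers implementation:

# Sketch only: the rough idea behind get_parameter_names.
from torch import nn

def parameter_names_excluding(module, forbidden_layer_types, prefix=""):
    names = []
    for child_name, child in module.named_children():
        if not isinstance(child, tuple(forbidden_layer_types)):
            # recurse only into children that are not of a forbidden type
            names += parameter_names_excluding(child, forbidden_layer_types, f"{prefix}{child_name}.")
    # parameters defined directly on this module
    names += [f"{prefix}{n}" for n, _ in module.named_parameters(recurse=False)]
    return names

# e.g. parameter_names_excluding(model, ALL_LAYERNORM_LAYERS) would skip every
# parameter owned by an nn.LayerNorm, T5LayerNorm, or LongT5LayerNorm module.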
@@ -71,6 +71,7 @@ from .modelcard import TrainingSummary
 from .modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
 from .models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 from .optimization import Adafactor, get_scheduler
+from .pytorch_utils import ALL_LAYERNORM_LAYERS
 from .tokenization_utils_base import PreTrainedTokenizerBase
 from .trainer_callback import (
     CallbackHandler,
@@ -967,7 +968,7 @@ class Trainer:
         opt_model = self.model_wrapped if is_sagemaker_mp_enabled() else self.model

         if self.optimizer is None:
-            decay_parameters = get_parameter_names(opt_model, [nn.LayerNorm])
+            decay_parameters = get_parameter_names(opt_model, ALL_LAYERNORM_LAYERS)
             decay_parameters = [name for name in decay_parameters if "bias" not in name]
             optimizer_grouped_parameters = [
                 {
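End to end, the change means that T5's and LongT5's norm weights now land in the no-decay group the Trainer builds. A sketch that reproduces the grouping outside the Trainer, using the helpers as they exist after this commit; building the model from a default T5Config avoids downloading a checkpoint:

# Sketch only: check which parameters end up in the weight-decay group for T5.
from transformers import T5Config, T5ForConditionalGeneration
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names

model = T5ForConditionalGeneration(T5Config())  # randomly initialized, no download

decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS)
decay_parameters = [name for name in decay_parameters if "bias" not in name]

optimizer_grouped_parameters = [
    {"params": [p for n, p in model.named_parameters() if n in decay_parameters], "weight_decay": 0.01},
    {"params": [p for n, p in model.named_parameters() if n not in decay_parameters], "weight_decay": 0.0},
]

# With the fix, no layer norm weight should appear in the decay group:
print([n for n in decay_parameters if "layer_norm" in n])  # expected: []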