Unverified Commit 9de62cfb authored by Kumar Abhishek, committed by GitHub

[lm examples] Replicate --config_overrides addition to other LM examples (#12135)



* [lm examples] Replicate --config_overrides addition to other LM examples

* Removing changes to the no_trainer files

* Update README
Co-authored-by: Kumar Abhishek <kabhishek@expedia.com>
parent cd7961b6
@@ -173,7 +173,7 @@
python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="
[...]
```
-At the moment this is only available in `run_clm.py` but eventually should be copied to all other LM examples.
+This feature is only available in `run_clm.py`, `run_plm.py` and `run_mlm.py`.
This feature can also be used to activate gradient checkpointing by passing:
```
...
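For context, a minimal sketch (not part of this commit) of what such an override string does, assuming a GPT-2 config created from scratch; `update_from_string` is the `PretrainedConfig` helper that the scripts below call, and the attribute names are copied from the example in the new `--config_overrides` help text:

```
from transformers import GPT2Config

config = GPT2Config()  # fresh config, as when training from scratch
# Parses "key=value" pairs and updates the matching attributes, casting each
# value to the type of the attribute it replaces.
config.update_from_string("n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index")
print(config.n_embd, config.resid_pdrop, config.scale_attn_weights, config.summary_type)
# expected: 10 0.2 False cls_index
```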
@@ -72,6 +72,13 @@ class ModelArguments:
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": "Override some existing default config settings when a model is trained from scratch. Example: "
            "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
@@ -98,6 +105,12 @@ class ModelArguments:
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
@@ -283,6 +296,9 @@ def main():
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
...
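A rough usage sketch of the new `__post_init__` guard, assuming the `ModelArguments` dataclass above is importable and parsed with `transformers.HfArgumentParser` (as the example scripts do): `--config_overrides` only makes sense for a config built from scratch, so combining it with `--config_name` or `--model_name_or_path` raises a `ValueError` as soon as the arguments are parsed:

```
from transformers import HfArgumentParser

parser = HfArgumentParser(ModelArguments)

# Fine: training from scratch, only the freshly created config is overridden.
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--model_type", "bert", "--config_overrides", "hidden_size=256"]
)

# Raises ValueError: --config_overrides can't be used in combination with
# --config_name or --model_name_or_path.
(model_args,) = parser.parse_args_into_dataclasses(
    args=["--model_name_or_path", "bert-base-uncased", "--config_overrides", "hidden_size=256"]
)
```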
@@ -65,6 +65,13 @@ class ModelArguments:
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": "Override some existing default config settings when a model is trained from scratch. Example: "
            "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
        },
    )
    tokenizer_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"}
    )
@@ -88,6 +95,12 @@ class ModelArguments:
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


@dataclass
class DataTrainingArguments:
@@ -280,6 +293,9 @@ def main():
    else:
        config = XLNetConfig()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)

    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
...
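The second script changed here builds an `XLNetConfig` when training from scratch (the permutation-language-modeling example), and the same override path applies; a minimal sketch, with the XLNet attribute names (`d_model`, `n_head`) assumed purely for illustration:

```
from transformers import XLNetConfig

config = XLNetConfig()  # fresh config, as in the from-scratch branch above
config.update_from_string("d_model=256,n_head=4")
print(config.d_model, config.n_head)  # expected: 256 4
```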