Unverified Commit 1b653010 authored by Stas Bekman, committed by GitHub

[Examples] create model with custom config on the fly (#11798)



* create custom model on the fly

* better wording

* add update_from_string

* cleanup

* cleanup

* Update src/transformers/configuration_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* more bool options

* style

* fix logger

* add test

* add the doc

* assert on conflict of options
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6287c929
@@ -161,3 +161,21 @@ concatenates all texts and then splits them in blocks of the same length).
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
sure all your batches have the same length.

## Creating a model on the fly
When training a model from scratch, configuration values may be overridden with the help of `--config_overrides`:
```bash
python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \
    --config_overrides="n_embd=1024,n_head=16,n_layer=48,n_positions=102" \
[...]
```
At the moment this is only available in `run_clm.py`, but eventually it should be added to all the other LM examples as well.
This feature can also be used to activate gradient checkpointing by passing:
```
--config_overrides "gradient_checkpointing=true,use_cache=False"
```
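
For illustration, here is a rough Python equivalent of what the flag does when training from scratch (this snippet is not part of the example script; it relies on the `update_from_string` helper added to `PretrainedConfig` in this commit):

```python
from transformers import GPT2Config

# Start from the default GPT-2 config, then apply the comma-separated overrides,
# mirroring `--model_type gpt2 --config_overrides="..."` in run_clm.py.
config = GPT2Config()
config.update_from_string("n_embd=1024,n_head=16,n_layer=48,n_positions=102")
print(config.n_layer, config.n_embd)  # 48 1024

# run_clm.py then builds the model from this config (see the diff below):
# model = AutoModelForCausalLM.from_config(config)
```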
@@ -75,6 +75,13 @@ class ModelArguments:
        default=None,
        metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
    )
    config_overrides: Optional[str] = field(
        default=None,
        metadata={
            "help": "Override some existing default config settings when a model is trained from scratch. Example: "
            "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
        },
    )
    config_name: Optional[str] = field(
        default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
    )
@@ -101,6 +108,12 @@ class ModelArguments:
        },
    )

    def __post_init__(self):
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )

@dataclass
class DataTrainingArguments:
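
For illustration only, a self-contained sketch of the `__post_init__` guard added above (the real dataclass lives in `run_clm.py` and has more fields; `ModelArgumentsSketch` is a made-up stand-in):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgumentsSketch:
    # made-up stand-in for run_clm.py's ModelArguments, keeping only the fields the guard touches
    model_name_or_path: Optional[str] = None
    model_type: Optional[str] = None
    config_name: Optional[str] = None
    config_overrides: Optional[str] = None

    def __post_init__(self):
        # same conflict check as in the diff above
        if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
            raise ValueError(
                "--config_overrides can't be used in combination with --config_name or --model_name_or_path"
            )


ModelArgumentsSketch(model_type="gpt2", config_overrides="n_layer=6")  # fine: training from scratch
try:
    ModelArgumentsSketch(model_name_or_path="gpt2", config_overrides="n_layer=6")
except ValueError as err:
    print(err)  # conflicting options are rejected
```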
@@ -279,6 +292,9 @@ def main():
    else:
        config = CONFIG_MAPPING[model_args.model_type]()
        logger.warning("You are instantiating a new config instance from scratch.")
        if model_args.config_overrides is not None:
            logger.info(f"Overriding config: {model_args.config_overrides}")
            config.update_from_string(model_args.config_overrides)
    tokenizer_kwargs = {
        "cache_dir": model_args.cache_dir,
@@ -306,8 +322,9 @@ def main():
            use_auth_token=True if model_args.use_auth_token else None,
        )
    else:
        model = AutoModelForCausalLM.from_config(config)
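        # keying by data_ptr() counts each underlying parameter tensor once, so shared/tied weights are not double-counted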
        n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
        logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")

    model.resize_token_embeddings(len(tokenizer))
@@ -667,7 +667,45 @@ class PretrainedConfig(PushToHubMixin):
        Updates attributes of this class with attributes from ``config_dict``.

        Args:
            config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
        """
        for key, value in config_dict.items():
            setattr(self, key, value)

    def update_from_string(self, update_str: str):
        """
        Updates attributes of this class with attributes from ``update_str``.

        The expected format is ints, floats and strings as is, and for booleans use ``true`` or ``false``. For
        example: "n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"

        The keys to change have to already exist in the config object.

        Args:
            update_str (:obj:`str`): String with attributes that should be updated for this class.
        """
        d = dict(x.split("=") for x in update_str.split(","))
        for k, v in d.items():
            if not hasattr(self, k):
                raise ValueError(f"key {k} isn't in the original config dict")

            old_v = getattr(self, k)
            if isinstance(old_v, bool):
                if v.lower() in ["true", "1", "y", "yes"]:
                    v = True
                elif v.lower() in ["false", "0", "n", "no"]:
                    v = False
                else:
                    raise ValueError(f"can't derive true or false from {v} (key {k})")
            elif isinstance(old_v, int):
                v = int(v)
            elif isinstance(old_v, float):
                v = float(v)
            elif not isinstance(old_v, str):
                raise ValueError(
                    f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
                )

            setattr(self, k, v)
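
A quick usage sketch of the new method (`BertConfig` is used purely for illustration; the attributes below are standard config fields):

```python
from transformers import BertConfig

config = BertConfig()
config.update_from_string("num_hidden_layers=6,hidden_dropout_prob=0.2,tie_word_embeddings=false")
assert config.num_hidden_layers == 6        # "6" coerced to int to match the existing int value
assert config.hidden_dropout_prob == 0.2    # "0.2" coerced to float
assert config.tie_word_embeddings is False  # "false" coerced to bool

# Unknown keys are rejected rather than silently added:
try:
    config.update_from_string("not_a_real_key=1")
except ValueError as err:
    print(err)  # key not_a_real_key isn't in the original config dict
```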
@@ -21,7 +21,7 @@ import unittest
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig, GPT2Config
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
@@ -138,3 +138,21 @@ class ConfigPushToHubTester(unittest.TestCase):
        for k, v in config.__dict__.items():
            if k != "transformers_version":
                self.assertEqual(v, getattr(new_config, k))


class ConfigTestUtils(unittest.TestCase):
    def test_config_from_string(self):
        c = GPT2Config()

        # attempt to modify each of int/float/bool/str config records and verify they were updated
        n_embd = c.n_embd + 1  # int
        resid_pdrop = c.resid_pdrop + 1.0  # float
        scale_attn_weights = not c.scale_attn_weights  # bool
        summary_type = c.summary_type + "foo"  # str
        c.update_from_string(
            f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
        )
        self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
        self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
        self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
        self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")