Unverified Commit 1b653010 authored by Stas Bekman's avatar Stas Bekman Committed by GitHub
Browse files

[Examples] create model with custom config on the fly (#11798)



* create custom model on the flight

* better wording

* add update_from_string

* cleanup

* cleanup

* Update src/transformers/configuration_utils.py
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>

* more bool options

* style

* fix logger

* add test

* add the doc

* assert on conflict of options
Co-authored-by: default avatarSylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 6287c929
......@@ -161,3 +161,21 @@ concatenates all texts and then splits them in blocks of the same length).
**Note:** On TPU, you should use the flag `--pad_to_max_length` in conjunction with the `--line_by_line` flag to make
sure all your batches have the same length.
## Creating a model on the fly
When training a model from scratch, configuration values may be overridden with the help of `--config_overrides`:
```bash
python run_clm.py --model_type gpt2 --tokenizer_name gpt2 \ --config_overrides="n_embd=1024,n_head=16,n_layer=48,n_positions=102" \
[...]
```
At the moment this is only available in `run_clm.py` but eventually should be copied to all other LM examples.
This feature can also be used to activate gradient checkpointing by passing:
```
--config_overrides "gradient_checkpointing=true,use_cache=False"
```
......@@ -75,6 +75,13 @@ class ModelArguments:
default=None,
metadata={"help": "If training from scratch, pass a model type from the list: " + ", ".join(MODEL_TYPES)},
)
config_overrides: Optional[str] = field(
default=None,
metadata={
"help": "Override some existing default config settings when a model is trained from scratch. Example: "
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
},
)
config_name: Optional[str] = field(
default=None, metadata={"help": "Pretrained config name or path if not the same as model_name"}
)
......@@ -101,6 +108,12 @@ class ModelArguments:
},
)
def __post_init__(self):
if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
raise ValueError(
"--config_overrides can't be used in combination with --config_name or --model_name_or_path"
)
@dataclass
class DataTrainingArguments:
......@@ -279,6 +292,9 @@ def main():
else:
config = CONFIG_MAPPING[model_args.model_type]()
logger.warning("You are instantiating a new config instance from scratch.")
if model_args.config_overrides is not None:
logger.info(f"Overriding config: {model_args.config_overrides}")
config.update_from_string(model_args.config_overrides)
tokenizer_kwargs = {
"cache_dir": model_args.cache_dir,
......@@ -306,8 +322,9 @@ def main():
use_auth_token=True if model_args.use_auth_token else None,
)
else:
logger.info("Training new model from scratch")
model = AutoModelForCausalLM.from_config(config)
n_params = sum(dict((p.data_ptr(), p.numel()) for p in model.parameters()).values())
logger.info(f"Training new model from scratch - Total size={n_params/2**20:.2f}M params")
model.resize_token_embeddings(len(tokenizer))
......
......@@ -667,7 +667,45 @@ class PretrainedConfig(PushToHubMixin):
Updates attributes of this class with attributes from ``config_dict``.
Args:
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that shall be updated for this class.
config_dict (:obj:`Dict[str, Any]`): Dictionary of attributes that should be updated for this class.
"""
for key, value in config_dict.items():
setattr(self, key, value)
def update_from_string(self, update_str: str):
"""
Updates attributes of this class with attributes from ``update_str``.
The expected format is ints, floats and strings as is, and for booleans use ``true`` or ``false``. For example:
"n_embd=10,resid_pdrop=0.2,scale_attn_weights=false,summary_type=cls_index"
The keys to change have to already exist in the config object.
Args:
update_str (:obj:`str`): String with attributes that should be updated for this class.
"""
d = dict(x.split("=") for x in update_str.split(","))
for k, v in d.items():
if not hasattr(self, k):
raise ValueError(f"key {k} isn't in the original config dict")
old_v = getattr(self, k)
if isinstance(old_v, bool):
if v.lower() in ["true", "1", "y", "yes"]:
v = True
elif v.lower() in ["false", "0", "n", "no"]:
v = False
else:
raise ValueError(f"can't derive true or false from {v} (key {k})")
elif isinstance(old_v, int):
v = int(v)
elif isinstance(old_v, float):
v = float(v)
elif not isinstance(old_v, str):
raise ValueError(
f"You can only update int, float, bool or string values in the config, got {v} for key {k}"
)
setattr(self, k, v)
......@@ -21,7 +21,7 @@ import unittest
from huggingface_hub import HfApi
from requests.exceptions import HTTPError
from transformers import BertConfig
from transformers import BertConfig, GPT2Config
from transformers.testing_utils import ENDPOINT_STAGING, PASS, USER, is_staging_test
......@@ -138,3 +138,21 @@ class ConfigPushToHubTester(unittest.TestCase):
for k, v in config.__dict__.items():
if k != "transformers_version":
self.assertEqual(v, getattr(new_config, k))
class ConfigTestUtils(unittest.TestCase):
def test_config_from_string(self):
c = GPT2Config()
# attempt to modify each of int/float/bool/str config records and verify they were updated
n_embd = c.n_embd + 1 # int
resid_pdrop = c.resid_pdrop + 1.0 # float
scale_attn_weights = not c.scale_attn_weights # bool
summary_type = c.summary_type + "foo" # str
c.update_from_string(
f"n_embd={n_embd},resid_pdrop={resid_pdrop},scale_attn_weights={scale_attn_weights},summary_type={summary_type}"
)
self.assertEqual(n_embd, c.n_embd, "mismatch for key: n_embd")
self.assertEqual(resid_pdrop, c.resid_pdrop, "mismatch for key: resid_pdrop")
self.assertEqual(scale_attn_weights, c.scale_attn_weights, "mismatch for key: scale_attn_weights")
self.assertEqual(summary_type, c.summary_type, "mismatch for key: summary_type")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment