Unverified Commit ddfdf470 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Add trust_remote_code arg to get_config (#405)

parent b6fbb9a5
......@@ -54,7 +54,7 @@ class ModelConfig:
self.use_dummy_weights = use_dummy_weights
self.seed = seed
self.hf_config = get_config(model)
self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode()
......
......@@ -7,8 +7,21 @@ _CONFIG_REGISTRY = {
}
def get_config(model: str) -> PretrainedConfig:
config = AutoConfig.from_pretrained(model, trust_remote_code=True)
def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
err_msg = (
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model)
......
......@@ -34,8 +34,8 @@ def get_tokenizer(
try:
tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name,
trust_remote_code=trust_remote_code,
*args,
trust_remote_code=trust_remote_code,
**kwargs)
except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments.
......@@ -47,13 +47,14 @@ def get_tokenizer(
except ValueError as e:
# If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag.
if (e is not None and
if (not trust_remote_code and
("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))):
err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers "
"library, consider using the --trust-remote-code flag.")
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment