Unverified Commit ddfdf470 authored by Woosuk Kwon's avatar Woosuk Kwon Committed by GitHub
Browse files

Add trust_remote_code arg to get_config (#405)

parent b6fbb9a5
...@@ -54,7 +54,7 @@ class ModelConfig: ...@@ -54,7 +54,7 @@ class ModelConfig:
self.use_dummy_weights = use_dummy_weights self.use_dummy_weights = use_dummy_weights
self.seed = seed self.seed = seed
self.hf_config = get_config(model) self.hf_config = get_config(model, trust_remote_code)
self.dtype = _get_and_verify_dtype(self.hf_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self._verify_tokenizer_mode() self._verify_tokenizer_mode()
......
...@@ -7,8 +7,21 @@ _CONFIG_REGISTRY = { ...@@ -7,8 +7,21 @@ _CONFIG_REGISTRY = {
} }
def get_config(model: str) -> PretrainedConfig: def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
config = AutoConfig.from_pretrained(model, trust_remote_code=True) try:
config = AutoConfig.from_pretrained(
model, trust_remote_code=trust_remote_code)
except ValueError as e:
if (not trust_remote_code and
"requires you to execute the configuration file" in str(e)):
err_msg = (
"Failed to load the model config. If the model is a custom "
"model not yet available in the HuggingFace transformers "
"library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e
else:
raise e
if config.model_type in _CONFIG_REGISTRY: if config.model_type in _CONFIG_REGISTRY:
config_class = _CONFIG_REGISTRY[config.model_type] config_class = _CONFIG_REGISTRY[config.model_type]
config = config_class.from_pretrained(model) config = config_class.from_pretrained(model)
......
...@@ -34,8 +34,8 @@ def get_tokenizer( ...@@ -34,8 +34,8 @@ def get_tokenizer(
try: try:
tokenizer = AutoTokenizer.from_pretrained( tokenizer = AutoTokenizer.from_pretrained(
tokenizer_name, tokenizer_name,
trust_remote_code=trust_remote_code,
*args, *args,
trust_remote_code=trust_remote_code,
**kwargs) **kwargs)
except TypeError as e: except TypeError as e:
# The LLaMA tokenizer causes a protobuf error in some environments. # The LLaMA tokenizer causes a protobuf error in some environments.
...@@ -47,13 +47,14 @@ def get_tokenizer( ...@@ -47,13 +47,14 @@ def get_tokenizer(
except ValueError as e: except ValueError as e:
# If the error pertains to the tokenizer class not existing or not # If the error pertains to the tokenizer class not existing or not
# currently being imported, suggest using the --trust-remote-code flag. # currently being imported, suggest using the --trust-remote-code flag.
if (e is not None and if (not trust_remote_code and
("does not exist or is not currently imported." in str(e) ("does not exist or is not currently imported." in str(e)
or "requires you to execute the tokenizer file" in str(e))): or "requires you to execute the tokenizer file" in str(e))):
err_msg = ( err_msg = (
"Failed to load the tokenizer. If the tokenizer is a custom " "Failed to load the tokenizer. If the tokenizer is a custom "
"tokenizer not yet available in the HuggingFace transformers " "tokenizer not yet available in the HuggingFace transformers "
"library, consider using the --trust-remote-code flag.") "library, consider setting `trust_remote_code=True` in LLM "
"or using the `--trust-remote-code` flag in the CLI.")
raise RuntimeError(err_msg) from e raise RuntimeError(err_msg) from e
else: else:
raise e raise e
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment