Unverified Commit 14510938 authored by Jackmin801, committed by GitHub

Allow `trust_remote_code` in example scripts (#25248)

* pytorch examples

* pytorch mim no trainer

* cookiecutter

* flax examples

* missed line in pytorch run_glue

* tensorflow examples

* tensorflow run_clip

* tensorflow run_mlm

* tensorflow run_ner

* tensorflow run_clm

* pytorch example from_configs

* pytorch no trainer examples

* Revert "tensorflow run_clip"

This reverts commit 261f86ac1f1c9e05dd3fd0291e1a1f8e573781d5.

* fix: duplicated argument

parent 65001cb1
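Every changed script gains the same `ModelArguments` field shown in the hunks below. As a minimal sketch of how that field becomes a command-line switch (not part of this diff; it only assumes the documented `HfArgumentParser` behavior, with the help text abbreviated):

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class ModelArguments:
    # Mirrors the field added by this commit; help text shortened here.
    trust_remote_code: bool = field(
        default=False,
        metadata={"help": "Allow custom model code from the Hub to run locally."},
    )


parser = HfArgumentParser(ModelArguments)
# A bool field defaulting to False is exposed as a store_true-style flag.
(model_args,) = parser.parse_args_into_dataclasses(["--trust_remote_code"])
assert model_args.trust_remote_code is True
```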
@@ -126,6 +126,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )

     def __post_init__(self):
         if self.config_overrides is not None and (self.config_name is not None or self.model_name_or_path is not None):
@@ -348,19 +358,25 @@ def main():
     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
     if checkpoint is not None:
-        config = AutoConfig.from_pretrained(checkpoint)
+        config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=model_args.trust_remote_code)
     elif model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name)
+        config = AutoConfig.from_pretrained(model_args.config_name, trust_remote_code=model_args.trust_remote_code)
     elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path)
+        config = AutoConfig.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+        )
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")

     if model_args.tokenizer_name:
-        tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.tokenizer_name, trust_remote_code=model_args.trust_remote_code
+        )
     elif model_args.model_name_or_path:
-        tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+        tokenizer = AutoTokenizer.from_pretrained(
+            model_args.model_name_or_path, trust_remote_code=model_args.trust_remote_code
+        )
     else:
         raise ValueError(
             "You are instantiating a new tokenizer from scratch. This is not supported by this script."
@@ -495,12 +511,16 @@ def main():
     with training_args.strategy.scope():
         # region Prepare model
         if checkpoint is not None:
-            model = TFAutoModelForMaskedLM.from_pretrained(checkpoint, config=config)
+            model = TFAutoModelForMaskedLM.from_pretrained(
+                checkpoint, config=config, trust_remote_code=model_args.trust_remote_code
+            )
         elif model_args.model_name_or_path:
-            model = TFAutoModelForMaskedLM.from_pretrained(model_args.model_name_or_path, config=config)
+            model = TFAutoModelForMaskedLM.from_pretrained(
+                model_args.model_name_or_path, config=config, trust_remote_code=model_args.trust_remote_code
+            )
         else:
             logger.info("Training new model from scratch")
-            model = TFAutoModelForMaskedLM.from_config(config)
+            model = TFAutoModelForMaskedLM.from_config(config, trust_remote_code=model_args.trust_remote_code)

         # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
         # on a small vocab and want a smaller embedding size, remove this test.
......
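Between files, the pattern is uniform: the flag is threaded into every `from_pretrained`/`from_config` call so that architectures whose code lives in the model repo can load. A hedged sketch of what the flag gates (the repo id below is hypothetical; only opt in for repositories whose code you have read):

```python
from transformers import AutoConfig, AutoModel

# Hypothetical Hub repo that ships its own configuration_*.py / modeling_*.py.
repo_id = "some-org/custom-model"

# Without trust_remote_code=True, loading such a repo fails with a prompt to
# opt in, because the repo's Python code would execute on the local machine.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_config(config, trust_remote_code=True)
```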
@@ -162,6 +162,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )


 @dataclass
@@ -349,6 +359,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -356,6 +367,7 @@ def main():
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # endregion
@@ -442,6 +454,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     num_replicas = training_args.strategy.num_replicas_in_sync
......
@@ -93,6 +93,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )


 @dataclass
@@ -352,6 +362,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -359,6 +370,7 @@ def main():
         use_fast=True,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # endregion
@@ -639,6 +651,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     if training_args.do_train:
         training_dataset = model.prepare_tf_dataset(
......
@@ -115,6 +115,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )


 @dataclass
@@ -402,6 +412,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -409,6 +420,7 @@ def main():
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )

     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
@@ -527,6 +539,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
......
@@ -180,6 +180,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )

 # endregion
@@ -298,6 +308,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -305,6 +316,7 @@ def main():
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # endregion
@@ -388,6 +400,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # endregion
......
@@ -186,6 +186,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )

 # endregion
@@ -315,6 +325,7 @@ def main():
             cache_dir=model_args.cache_dir,
             revision=model_args.model_revision,
             token=model_args.token,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         config = AutoConfig.from_pretrained(
@@ -322,12 +333,14 @@ def main():
             cache_dir=model_args.cache_dir,
             revision=model_args.model_revision,
             token=model_args.token,
+            trust_remote_code=model_args.trust_remote_code,
         )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # endregion
@@ -416,6 +429,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # endregion
......
@@ -91,6 +91,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )


 @dataclass
@@ -304,9 +314,17 @@ def main():
     # In distributed training, the .from_pretrained methods guarantee that only one local process can concurrently
     # download model & vocab.
     if model_args.config_name:
-        config = AutoConfig.from_pretrained(model_args.config_name, num_labels=num_labels)
+        config = AutoConfig.from_pretrained(
+            model_args.config_name,
+            num_labels=num_labels,
+            trust_remote_code=model_args.trust_remote_code,
+        )
     elif model_args.model_name_or_path:
-        config = AutoConfig.from_pretrained(model_args.model_name_or_path, num_labels=num_labels)
+        config = AutoConfig.from_pretrained(
+            model_args.model_name_or_path,
+            num_labels=num_labels,
+            trust_remote_code=model_args.trust_remote_code,
+        )
     else:
         config = CONFIG_MAPPING[model_args.model_type]()
         logger.warning("You are instantiating a new config instance from scratch.")
@@ -319,9 +337,18 @@ def main():
     )
     if config.model_type in {"gpt2", "roberta"}:
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True, add_prefix_space=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name_or_path,
+            use_fast=True,
+            add_prefix_space=True,
+            trust_remote_code=model_args.trust_remote_code,
+        )
     else:
-        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name_or_path, use_fast=True)
+        tokenizer = AutoTokenizer.from_pretrained(
+            tokenizer_name_or_path,
+            use_fast=True,
+            trust_remote_code=model_args.trust_remote_code,
+        )
     # endregion

     # region Preprocessing the raw datasets
@@ -392,10 +419,13 @@ def main():
         model = TFAutoModelForTokenClassification.from_pretrained(
             model_args.model_name_or_path,
             config=config,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         logger.info("Training new model from scratch")
-        model = TFAutoModelForTokenClassification.from_config(config)
+        model = TFAutoModelForTokenClassification.from_config(
+            config, trust_remote_code=model_args.trust_remote_code
+        )

     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
     # on a small vocab and want a smaller embedding size, remove this test.
......
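The token-classification hunk above also covers the from-scratch branch. `from_config` downloads no weights, but a config fetched from a custom repo still implies custom modeling code, which is why the flag is forwarded there too. A small sketch of that branch (assumes TensorFlow is installed; `bert-base-cased` stands in for any config):

```python
from transformers import AutoConfig, TFAutoModelForTokenClassification

# Build a config, then a freshly initialized (randomly weighted) model from it.
config = AutoConfig.from_pretrained("bert-base-cased", num_labels=9)
model = TFAutoModelForTokenClassification.from_config(config)
```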
@@ -109,6 +109,16 @@ class ModelArguments:
             "help": "The `use_auth_token` argument is deprecated and will be removed in v4.34. Please use `token`."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )


 @dataclass
@@ -366,6 +376,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -373,6 +384,7 @@ def main():
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )

     prefix = data_args.source_prefix if data_args.source_prefix is not None else ""
@@ -480,6 +492,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=model_args.token,
+        trust_remote_code=model_args.trust_remote_code,
     )
     # We resize the embeddings only when necessary to avoid index errors. If you are creating a model from scratch
......
@@ -122,6 +122,16 @@ class ModelArguments:
             "with private models)."
         },
     )
+    trust_remote_code: bool = field(
+        default=False,
+        metadata={
+            "help": (
+                "Whether or not to allow for custom models defined on the Hub in their own modeling files. This option"
+                "should only be set to `True` for repositories you trust and in which you have read the code, as it will"
+                "execute code present on the Hub on your local machine."
+            )
+        },
+    )
 {% endif %}
@@ -290,6 +300,7 @@ def main():
         "cache_dir": model_args.cache_dir,
         "revision": model_args.model_revision,
         "token": True if model_args.token else None,
+        "trust_remote_code": model_args.trust_remote_code,
     }
     if model_args.config_name:
         config = AutoConfig.from_pretrained(model_args.config_name, **config_kwargs)
@@ -304,6 +315,7 @@ def main():
         "use_fast": model_args.use_fast_tokenizer,
         "revision": model_args.model_revision,
         "token": True if model_args.token else None,
+        "trust_remote_code": model_args.trust_remote_code,
     }
     if model_args.tokenizer_name:
         tokenizer = AutoTokenizer.from_pretrained(model_args.tokenizer_name, **tokenizer_kwargs)
@@ -323,6 +335,7 @@ def main():
             cache_dir=model_args.cache_dir,
             revision=model_args.model_revision,
             token=True if model_args.token else None,
+            trust_remote_code=model_args.trust_remote_code,
         )
     else:
         logger.info("Training new model from scratch")
@@ -337,6 +350,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=True if model_args.token else None,
+        trust_remote_code=model_args.trust_remote_code,
     )
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -344,6 +358,7 @@ def main():
         use_fast=model_args.use_fast_tokenizer,
         revision=model_args.model_revision,
         token=True if model_args.token else None,
+        trust_remote_code=model_args.trust_remote_code,
     )
     model = AutoModelForSequenceClassification.from_pretrained(
         model_args.model_name_or_path,
@@ -352,6 +367,7 @@ def main():
         cache_dir=model_args.cache_dir,
         revision=model_args.model_revision,
         token=True if model_args.token else None,
+        trust_remote_code=model_args.trust_remote_code,
     )
 {% endif %}
......
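The cookiecutter template above routes the flag through the `config_kwargs` and `tokenizer_kwargs` dicts. That pattern is the most likely context for the "fix: duplicated argument" bullet in the commit message: if the same keyword is also passed explicitly at a call site, Python rejects the call. A standalone illustration with a stub (not transformers code):

```python
def from_pretrained_stub(name, **kwargs):
    # Stand-in for AutoConfig.from_pretrained, just to show the error shape.
    return name, kwargs


config_kwargs = {"trust_remote_code": True}

try:
    # The keyword appears both explicitly and inside **config_kwargs.
    from_pretrained_stub("gpt2", trust_remote_code=True, **config_kwargs)
except TypeError as err:
    print(err)  # got multiple values for keyword argument 'trust_remote_code'
```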