Unverified Commit ab019eea authored by Jasmond L, committed by GitHub

Add Model Revision Support (#1014)


Co-authored-by: Jasmond Loh <Jasmond.Loh@hotmail.com>
Co-authored-by: Zhuohan Li <zhuohan123@gmail.com>
parent 9841d48a
@@ -38,6 +38,9 @@ class ModelConfig:
             will use FP16 precision for FP32 and FP16 models, and BF16 precision
             for BF16 models.
         seed: Random seed for reproducibility.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id. If unspecified, will use the default
+            version.
         max_model_len: Maximum length of a sequence (including prompt and
             output). If None, will be derived from the model.
     """
@@ -52,6 +55,7 @@ class ModelConfig:
         load_format: str,
         dtype: str,
         seed: int,
+        revision: Optional[str],
         max_model_len: Optional[int] = None,
     ) -> None:
         self.model = model
@@ -61,8 +65,9 @@ class ModelConfig:
         self.download_dir = download_dir
         self.load_format = load_format
         self.seed = seed
+        self.revision = revision
-        self.hf_config = get_config(model, trust_remote_code)
+        self.hf_config = get_config(model, trust_remote_code, revision)
         self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
         self._verify_load_format()
         self._verify_tokenizer_mode()
...
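The hunks above add a `revision` field to ModelConfig and thread it into get_config. As a rough illustration only (the model name and values are placeholders, and the parameter list follows the positional call that EngineArgs makes further down in this diff), constructing the config with a pinned revision would look like:

    from vllm.config import ModelConfig

    # Sketch: all values below are illustrative, not defaults from this commit.
    model_config = ModelConfig(
        model="facebook/opt-125m",
        tokenizer="facebook/opt-125m",
        tokenizer_mode="auto",
        trust_remote_code=False,
        download_dir=None,
        load_format="auto",
        dtype="auto",
        seed=0,
        revision="main",  # branch name, tag name, or commit id
        max_model_len=None,
    )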
@@ -28,6 +28,7 @@ class EngineArgs:
     max_num_batched_tokens: int = 2560
     max_num_seqs: int = 256
     disable_log_stats: bool = False
+    revision: Optional[str] = None

     def __post_init__(self):
         if self.tokenizer is None:
@@ -49,6 +50,13 @@ class EngineArgs:
                             type=str,
                             default=EngineArgs.tokenizer,
                             help='name or path of the huggingface tokenizer to use')
+        parser.add_argument(
+            '--revision',
+            type=str,
+            default=None,
+            help='the specific model version to use. It can be a branch '
+            'name, a tag name, or a commit id. If unspecified, will use '
+            'the default version.')
         parser.add_argument('--tokenizer-mode',
                             type=str,
                             default=EngineArgs.tokenizer_mode,
@@ -159,7 +167,8 @@ class EngineArgs:
         model_config = ModelConfig(self.model, self.tokenizer,
                                    self.tokenizer_mode, self.trust_remote_code,
                                    self.download_dir, self.load_format,
-                                   self.dtype, self.seed, self.max_model_len)
+                                   self.dtype, self.seed, self.revision,
+                                   self.max_model_len)
         cache_config = CacheConfig(self.block_size,
                                    self.gpu_memory_utilization,
                                    self.swap_space)
...
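With `--revision` registered in the argument parser, the flag can be parsed like any other engine argument. A minimal sketch (model and revision values are illustrative; it assumes the usual EngineArgs.add_cli_args/from_cli_args helpers):

    import argparse

    from vllm.engine.arg_utils import EngineArgs

    # Build a parser with the engine flags, including the new --revision option.
    parser = argparse.ArgumentParser()
    parser = EngineArgs.add_cli_args(parser)
    args = parser.parse_args([
        "--model", "facebook/opt-125m",
        "--revision", "main",  # branch name, tag name, or commit id
    ])
    engine_args = EngineArgs.from_cli_args(args)
    print(engine_args.revision)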
@@ -74,6 +74,7 @@ class LLMEngine:
             f"model={model_config.model!r}, "
             f"tokenizer={model_config.tokenizer!r}, "
             f"tokenizer_mode={model_config.tokenizer_mode}, "
+            f"revision={model_config.revision}, "
             f"trust_remote_code={model_config.trust_remote_code}, "
             f"dtype={model_config.dtype}, "
             f"download_dir={model_config.download_dir!r}, "
@@ -92,7 +93,8 @@ class LLMEngine:
         self.tokenizer = get_tokenizer(
             model_config.tokenizer,
             tokenizer_mode=model_config.tokenizer_mode,
-            trust_remote_code=model_config.trust_remote_code)
+            trust_remote_code=model_config.trust_remote_code,
+            revision=model_config.revision)
         self.seq_counter = Counter()
         # Create the parallel GPU workers.
...
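get_tokenizer now receives the revision as well, so the tokenizer files are pinned to the same version as the weights and config. Under the hood this presumably reaches Hugging Face's tokenizer loader; a rough standalone equivalent (not vLLM's exact code, values illustrative):

    from transformers import AutoTokenizer

    # Load the tokenizer pinned to a specific revision.
    tokenizer = AutoTokenizer.from_pretrained(
        "facebook/opt-125m",
        revision="main",  # branch name, tag name, or commit id
        trust_remote_code=False,
    )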
@@ -38,6 +38,8 @@ class LLM:
             However, if the `torch_dtype` in the config is `float32`, we will
             use `float16` instead.
         seed: The seed to initialize the random number generator for sampling.
+        revision: The specific model version to use. It can be a branch name,
+            a tag name, or a commit id.
     """

     def __init__(
...
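The LLM docstring now documents `revision`; assuming the constructor forwards it to the engine as the docstring suggests (the parameter change itself is collapsed in this view), end-to-end usage would look roughly like:

    from vllm import LLM, SamplingParams

    # Sketch: model name and revision are placeholders.
    llm = LLM(model="facebook/opt-125m", revision="main")
    outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
    for output in outputs:
        print(output.outputs[0].text)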
@@ -64,6 +64,6 @@ def get_model(model_config: ModelConfig) -> nn.Module:
     else:
         # Load the weights from the cached or downloaded files.
         model.load_weights(model_config.model, model_config.download_dir,
-                           model_config.load_format)
+                           model_config.load_format, model_config.revision)
         model = model.cuda()
     return model.eval()
@@ -288,7 +288,8 @@ class AquilaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -305,7 +306,7 @@ class AquilaForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
...
@@ -303,13 +303,14 @@ class BaiChuanBaseForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
...
@@ -279,11 +279,12 @@ class BloomForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if name == "lm_head.weight":
                 # Since hidden_states are parallelized, we need to
                 # load lm_head.weight in parallel.
...
@@ -420,7 +420,8 @@ class FalconForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = (get_tensor_model_parallel_world_size())
         tp_rank = get_tensor_model_parallel_rank()
@@ -452,7 +453,7 @@ class FalconForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "query_key_value" in name:
                 loaded_weight = convert_pyslice_to_tensor(loaded_weight)
                 loaded_weight_size = loaded_weight.size()
...
@@ -231,14 +231,15 @@ class GPT2LMHeadModel(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
...
@@ -259,14 +259,15 @@ class GPTBigCodeForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_world_size = (
             get_tensor_model_parallel_world_size())
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
                 # linear layer.
...
@@ -222,11 +222,12 @@ class GPTJForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
...
@@ -231,11 +231,12 @@ class GPTNeoXForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
                 continue
...
@@ -233,12 +233,13 @@ class InternLMForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
...
@@ -271,7 +271,8 @@ class LlamaForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_size = get_tensor_model_parallel_world_size()
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         q_proj_shard_size = (self.config.hidden_size // tp_size)
@@ -288,7 +289,7 @@ class LlamaForCausalLM(nn.Module):
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
...
@@ -244,12 +244,13 @@ class MPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "Wqkv" in name:
                 # NOTE(woosuk): MPT's fused QKV has the shape of
                 # [3 * num_heads * head_size, hidden_size].
...
@@ -297,12 +297,13 @@ class OPTForCausalLM(nn.Module):
     def load_weights(self,
                      model_name_or_path: str,
                      cache_dir: Optional[str] = None,
-                     load_format: str = "auto"):
+                     load_format: str = "auto",
+                     revision: Optional[str] = None):
         tensor_model_parallel_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "lm_head.weight" in name:
                 continue
...
@@ -251,13 +251,14 @@ class QWenLMHeadModel(nn.Module):
         model_name_or_path: str,
         cache_dir: Optional[str] = None,
         load_format: str = "auto",
+        revision: Optional[str] = None,
     ):
         tp_world_size = get_tensor_model_parallel_world_size()
         tp_rank = get_tensor_model_parallel_rank()
         state_dict = self.state_dict()
         for name, loaded_weight in hf_model_weights_iterator(
-                model_name_or_path, cache_dir, load_format):
+                model_name_or_path, cache_dir, load_format, revision):
             if "rotary_emb.inv_freq" in name:
                 continue
...
@@ -83,6 +83,7 @@ def prepare_hf_model_weights(
     cache_dir: Optional[str] = None,
     use_safetensors: bool = False,
     fall_back_to_pt: bool = True,
+    revision: Optional[str] = None,
 ):
     # Download model weights from huggingface.
     is_local = os.path.isdir(model_name_or_path)
@@ -94,7 +95,8 @@ def prepare_hf_model_weights(
         hf_folder = snapshot_download(model_name_or_path,
                                       allow_patterns=allow_patterns,
                                       cache_dir=cache_dir,
-                                      tqdm_class=Disabledtqdm)
+                                      tqdm_class=Disabledtqdm,
+                                      revision=revision)
     else:
         hf_folder = model_name_or_path
     hf_weights_files = glob.glob(os.path.join(hf_folder, allow_patterns))
@@ -107,7 +109,8 @@ def prepare_hf_model_weights(
         return prepare_hf_model_weights(model_name_or_path,
                                         cache_dir=cache_dir,
                                         use_safetensors=False,
-                                        fall_back_to_pt=False)
+                                        fall_back_to_pt=False,
+                                        revision=revision)
     if len(hf_weights_files) == 0:
         raise RuntimeError(
@@ -120,6 +123,7 @@ def hf_model_weights_iterator(
     model_name_or_path: str,
     cache_dir: Optional[str] = None,
     load_format: str = "auto",
+    revision: Optional[str] = None,
 ) -> Iterator[Tuple[str, torch.Tensor]]:
     use_safetensors = False
     use_np_cache = False
@@ -140,7 +144,8 @@ def hf_model_weights_iterator(
         model_name_or_path,
         cache_dir=cache_dir,
         use_safetensors=use_safetensors,
-        fall_back_to_pt=fall_back_to_pt)
+        fall_back_to_pt=fall_back_to_pt,
+        revision=revision)
     if use_np_cache:
         # Currently np_cache only support *.bin checkpoints
...
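The revision ultimately reaches huggingface_hub's snapshot_download inside prepare_hf_model_weights, which is what actually pins the downloaded weight files. A standalone sketch of that call (repo id, patterns, and revision are illustrative):

    from huggingface_hub import snapshot_download

    # Download only the weight files of a repo at a pinned revision.
    hf_folder = snapshot_download(
        "facebook/opt-125m",
        allow_patterns="*.bin",
        cache_dir=None,
        revision="main",  # branch name, tag name, or commit id
    )
    print(hf_folder)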
+from typing import Optional
 from transformers import AutoConfig, PretrainedConfig
 from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
@@ -12,10 +14,12 @@ _CONFIG_REGISTRY = {
 }

-def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
+def get_config(model: str,
+               trust_remote_code: bool,
+               revision: Optional[str] = None) -> PretrainedConfig:
     try:
         config = AutoConfig.from_pretrained(
-            model, trust_remote_code=trust_remote_code)
+            model, trust_remote_code=trust_remote_code, revision=revision)
     except ValueError as e:
         if (not trust_remote_code and
                 "requires you to execute the configuration file" in str(e)):
@@ -29,5 +33,5 @@ def get_config(model: str, trust_remote_code: bool) -> PretrainedConfig:
             raise e
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
-        config = config_class.from_pretrained(model)
+        config = config_class.from_pretrained(model, revision=revision)
     return config
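get_config simply forwards the revision to transformers, so the equivalent direct call (model name and revision are illustrative) is:

    from transformers import AutoConfig

    # Load the model config pinned to a specific revision.
    config = AutoConfig.from_pretrained(
        "facebook/opt-125m",
        trust_remote_code=False,
        revision="main",  # branch name, tag name, or commit id
    )
    print(config.model_type)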