Unverified Commit 5a673bf8 authored by 정한 Rycont's avatar 정한 Rycont Committed by GitHub
Browse files

Pass `model_init_kwargs` to `check_and_get_model_type` function (#232)

parent 4d49ae1f
...@@ -20,8 +20,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = { ...@@ -20,8 +20,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"qwen": QwenAWQForCausalLM "qwen": QwenAWQForCausalLM
} }
def check_and_get_model_type(model_dir, trust_remote_code=True): def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code) config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs)
if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys(): if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys():
raise TypeError(f"{config.model_type} isn't supported yet.") raise TypeError(f"{config.model_type} isn't supported yet.")
model_type = config.model_type model_type = config.model_type
...@@ -35,7 +35,7 @@ class AutoAWQForCausalLM: ...@@ -35,7 +35,7 @@ class AutoAWQForCausalLM:
@classmethod @classmethod
def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False, def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
device_map=None, **model_init_kwargs) -> BaseAWQForCausalLM: device_map=None, **model_init_kwargs) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code) model_type = check_and_get_model_type(model_path, trust_remote_code, **model_init_kwargs)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained( return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors, model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment