Unverified Commit 5a673bf8 authored by 정한 Rycont's avatar 정한 Rycont Committed by GitHub
Browse files

Pass `model_init_kwargs` to `check_and_get_model_type` function (#232)

parent 4d49ae1f
......@@ -20,8 +20,8 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
"qwen": QwenAWQForCausalLM
}
def check_and_get_model_type(model_dir, trust_remote_code=True):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
def check_and_get_model_type(model_dir, trust_remote_code=True, **model_init_kwargs):
config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code, **model_init_kwargs)
if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys():
raise TypeError(f"{config.model_type} isn't supported yet.")
model_type = config.model_type
......@@ -35,7 +35,7 @@ class AutoAWQForCausalLM:
@classmethod
def from_pretrained(self, model_path, trust_remote_code=True, safetensors=False,
device_map=None, **model_init_kwargs) -> BaseAWQForCausalLM:
model_type = check_and_get_model_type(model_path, trust_remote_code)
model_type = check_and_get_model_type(model_path, trust_remote_code, **model_init_kwargs)
return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
model_path, model_type, trust_remote_code=trust_remote_code, safetensors=safetensors,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment