Commit fd5b8c88 authored by Casper Hansen

FP16 weights example works

parent 0834fb46
 from typing import List
 from awq.models import *
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
+from transformers.models.falcon.modeling_falcon import FalconForCausalLM
 class BasePrompter:
     def __init__(self, system_inst, role1, role2, sen_spliter = "\n", qa_spliter = "\n", decorator: List[str] = None):
@@ -127,14 +129,14 @@ class MPTChatPrompter(BasePrompter):
 def get_prompter(model, model_path = ""):
-    if isinstance(model, LlamaAWQForCausalLM):
+    if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
         if "vicuna" in model_path:
             return VicunaPrompter()
         else:
             return Llama2Prompter()
-    elif isinstance(model, FalconAWQForCausalLM):
+    elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
         return FalconSimplePrompter()
-    elif isinstance(model, MptAWQForCausalLM):
+    elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
         if "mpt" and "chat" in model_path:
             return MPTChatPrompter()
         else:
@@ -143,14 +145,15 @@ def get_prompter(model, model_path = ""):
         raise ValueError(f"model type {model.model_type} is not supported")

 def get_stop_token_ids(model, model_path = ""):
-    if isinstance(model, LlamaAWQForCausalLM):
+    if isinstance(model, LlamaAWQForCausalLM) or isinstance(model, LlamaForCausalLM):
         return []
-    elif isinstance(model, FalconAWQForCausalLM):
+    elif isinstance(model, FalconAWQForCausalLM) or isinstance(model, FalconForCausalLM):
         return [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]
-    elif isinstance(model, MptAWQForCausalLM):
+    elif isinstance(model, MptAWQForCausalLM) or "mpt" in str(model.__class__).lower():
         if "mpt" and "chat" in model_path:
             return [50278, 0]
         else:
             return []
     else:
-        raise ValueError(f"model type {model.model_type} is not supported")
+        model_type = str(model.__class__).lower()
+        raise ValueError(f"model type {model_type} is not supported")