"vscode:/vscode.git/clone" did not exist on "172260f9cd6641b18455d228d0760ca1c81a44f3"
Unverified Commit 1a8f995c authored by Lianmin Zheng's avatar Lianmin Zheng Committed by GitHub
Browse files

remove cache configs in model definitions (#4031)

parent a3ab768a
...@@ -359,7 +359,6 @@ class Grok1ForCausalLM(nn.Module): ...@@ -359,7 +359,6 @@ class Grok1ForCausalLM(nn.Module):
self, self,
config: PretrainedConfig, config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None: ) -> None:
super().__init__() super().__init__()
self.config = config self.config = config
......
...@@ -106,7 +106,6 @@ class LlamaForCausalLMEagle(LlamaForCausalLM): ...@@ -106,7 +106,6 @@ class LlamaForCausalLMEagle(LlamaForCausalLM):
self, self,
config: LlamaConfig, config: LlamaConfig,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None: ) -> None:
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
......
...@@ -107,7 +107,6 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM): ...@@ -107,7 +107,6 @@ class Qwen2ForCausalLMEagle(Qwen2ForCausalLM):
self, self,
config: Qwen2Config, config: Qwen2Config,
quant_config: Optional[QuantizationConfig] = None, quant_config: Optional[QuantizationConfig] = None,
cache_config=None,
) -> None: ) -> None:
nn.Module.__init__(self) nn.Module.__init__(self)
self.config = config self.config = config
......
...@@ -159,45 +159,6 @@ def call_generate_guidance( ...@@ -159,45 +159,6 @@ def call_generate_guidance(
return rets if n > 1 else rets[0] return rets if n > 1 else rets[0]
async def call_generate_lmql(
    prompt, temperature, max_tokens, stop=None, n=1, max_len=4096, model=None, **kwargs
):
    """Generate ``n`` completions for *prompt* with an LMQL program.

    Builds an LMQL query bound to *model* (with a STOPS_AT constraint only
    when *stop* is given, since LMQL rejects a None stop string), launches
    ``n`` generations concurrently, and awaits them together.

    Args:
        prompt: The question/context string fed to the query.
        temperature: Sampling temperature forwarded to LMQL.
        max_tokens: Upper bound on generated tokens (via TOKENS(ANSWER)).
        stop: Optional stop string; enables the STOPS_AT constraint.
        n: Number of completions to generate concurrently.
        max_len: Maximum context length forwarded to LMQL.
        model: An ``lmql.model`` handle; must not be None.
        **kwargs: Extra arguments forwarded to the LMQL program call.

    Returns:
        A list of ``n`` answers when ``n > 1``, otherwise the single answer.
    """
    assert model is not None
    import lmql

    # Two query variants: LMQL needs the STOPS_AT constraint compiled into
    # the program, so we cannot simply pass stop=None to one program.
    if stop is not None:

        @lmql.query(model=model)
        async def program(question, max_tokens, stop):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens and STOPS_AT(ANSWER, stop)
            return ANSWER
            '''

    else:

        @lmql.query(model=model)
        async def program(question, max_tokens):
            '''lmql
            """{question}[ANSWER]""" where len(TOKENS(ANSWER)) < max_tokens
            return ANSWER
            '''

    # Fan out n independent generations and gather them concurrently.
    tasks = [
        program(
            question=prompt,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            max_len=max_len,
            **kwargs,
        )
        for _ in range(n)
    ]
    rets = await asyncio.gather(*tasks)
    return rets if n > 1 else rets[0]
def call_select_lightllm(context, choices, url=None): def call_select_lightllm(context, choices, url=None):
assert url is not None assert url is not None
...@@ -247,23 +208,6 @@ def call_select_guidance(context, choices, model=None): ...@@ -247,23 +208,6 @@ def call_select_guidance(context, choices, model=None):
return choices.index(out["answer"]) return choices.index(out["answer"])
async def call_select_lmql(context, choices, temperature=0, max_len=4096, model=None):
    """Ask the LMQL *model* to pick one entry of *choices* given *context*.

    Compiles a constrained LMQL query whose answer must be a member of
    *choices*, runs it once, and maps the returned string back to its
    position in the *choices* list.

    Args:
        context: Prompt text preceding the answer slot.
        choices: Candidate answer strings the model must choose from.
        temperature: Sampling temperature forwarded to LMQL.
        max_len: Maximum context length forwarded to LMQL.
        model: An ``lmql.model`` handle; must not be None.

    Returns:
        The index into *choices* of the model's selection.
    """
    assert model is not None
    import lmql

    @lmql.query(model=model)
    async def program(ctx, choices):
        '''lmql
        """{ctx}[ANSWER]""" where ANSWER in set(choices)
        return ANSWER
        '''

    selected = await program(
        ctx=context,
        choices=choices,
        temperature=temperature,
        max_len=max_len,
    )
    return choices.index(selected)
def add_common_other_args_and_parse(parser: argparse.ArgumentParser): def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
parser.add_argument("--parallel", type=int, default=64) parser.add_argument("--parallel", type=int, default=64)
parser.add_argument("--host", type=str, default="http://127.0.0.1") parser.add_argument("--host", type=str, default="http://127.0.0.1")
...@@ -278,7 +222,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser): ...@@ -278,7 +222,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
"lightllm", "lightllm",
"gserver", "gserver",
"guidance", "guidance",
"lmql",
"srt-raw", "srt-raw",
"llama.cpp", "llama.cpp",
], ],
...@@ -295,7 +238,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser): ...@@ -295,7 +238,6 @@ def add_common_other_args_and_parse(parser: argparse.ArgumentParser):
"vllm": 21000, "vllm": 21000,
"outlines": 21000, "outlines": 21000,
"lightllm": 22000, "lightllm": 22000,
"lmql": 23000,
"srt-raw": 30000, "srt-raw": 30000,
"gserver": 9988, "gserver": 9988,
} }
...@@ -343,11 +285,6 @@ def _get_call_generate(args: argparse.Namespace): ...@@ -343,11 +285,6 @@ def _get_call_generate(args: argparse.Namespace):
call_generate = partial(call_generate_guidance, model=model) call_generate = partial(call_generate_guidance, model=model)
call_generate("Hello,", 1.0, 8, ".") call_generate("Hello,", 1.0, 8, ".")
return call_generate return call_generate
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
return partial(call_generate_lmql, model=model)
else: else:
raise ValueError(f"Invalid backend: {args.backend}") raise ValueError(f"Invalid backend: {args.backend}")
...@@ -365,12 +302,6 @@ def _get_call_select(args: argparse.Namespace): ...@@ -365,12 +302,6 @@ def _get_call_select(args: argparse.Namespace):
call_select("Hello,", ["world", "earth"]) call_select("Hello,", ["world", "earth"])
return call_select return call_select
elif args.backend == "lmql":
import lmql
model = lmql.model(args.model_path, endpoint=f"{args.host}:{args.port}")
return partial(call_select_lmql, model=model)
else: else:
raise ValueError(f"Invalid backend: {args.backend}") raise ValueError(f"Invalid backend: {args.backend}")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment