Unverified Commit 3b8e1cb7 authored by thatPepe, committed by GitHub

Merge pull request #260 from InfiniTensor/issue/259

issue/259 - add attn backend option to inference server
parents dfec9d89 91cd2992
......@@ -55,6 +55,7 @@ class EngineConfig:
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
"""
model_path: str
......@@ -71,6 +72,7 @@ class EngineConfig:
top_p: float = 0.8
top_k: int = 1
enable_graph: bool = False
attn_backend: str = "default"
class LLMEngine:
......@@ -88,6 +90,7 @@ class LLMEngine:
device=self.device,
distributed_config=DistConfig(config.tensor_parallel_size),
enable_graph_compiling=config.enable_graph,
attention_backend=config.attn_backend,
)
# Load model weights
......@@ -383,6 +386,7 @@ class LLM:
top_p: float = 0.8,
top_k: int = 1,
enable_graph: bool = False,
attn_backend: str = "default",
):
"""Initialize LLM.
......@@ -401,6 +405,7 @@ class LLM:
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
"""
config = EngineConfig(
model_path=model_path,
......@@ -417,6 +422,7 @@ class LLM:
top_p=top_p,
top_k=top_k,
enable_graph=enable_graph,
attn_backend=attn_backend,
)
self.engine = LLMEngine(config)
self.config = config
......@@ -536,6 +542,7 @@ class AsyncLLMEngine:
top_p: float = 0.8,
top_k: int = 1,
enable_graph: bool = False,
attn_backend: str = "default",
):
"""Initialize AsyncLLMEngine.
......@@ -554,6 +561,7 @@ class AsyncLLMEngine:
top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
"""
config = EngineConfig(
model_path=model_path,
......@@ -570,6 +578,7 @@ class AsyncLLMEngine:
top_p=top_p,
top_k=top_k,
enable_graph=enable_graph,
attn_backend=attn_backend,
)
self.engine = LLMEngine(config)
self.config = config
......
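For reviewers, a minimal usage sketch of the new option on the Python API side. The import path and the `generate()` call are assumptions for illustration, not taken from this diff; only the `attn_backend` keyword and its values come from the changes above.

```python
from infinilm import LLM  # import path assumed

# Select the FlashAttention backend instead of the default kernel;
# the value is forwarded through EngineConfig.attn_backend as added above.
llm = LLM(
    model_path="/data/shared/models/9G7B_MHA/",
    attn_backend="flash-attn",  # or "default"
)

# generate() signature assumed, shown only to complete the example.
outputs = llm.generate(["Hello, world!"])
```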
......@@ -108,6 +108,7 @@ class InferenceServer:
host: str = "0.0.0.0",
port: int = 8000,
enable_graph: bool = False,
attn_backend: str = "default",
):
"""Initialize inference server.
......@@ -128,6 +129,7 @@ class InferenceServer:
host: Server host address.
port: Server port number.
enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
"""
self.model_path = model_path
# vLLM-like served model id: directory name of model_path
......@@ -147,6 +149,7 @@ class InferenceServer:
self.host = host
self.port = port
self.enable_graph = enable_graph
self.attn_backend = attn_backend
self.engine: AsyncLLMEngine = None
......@@ -177,6 +180,7 @@ class InferenceServer:
top_p=self.top_p,
top_k=self.top_k,
enable_graph=self.enable_graph,
attn_backend=self.attn_backend,
)
self.engine.start()
logger.info(f"Engine initialized with model at {self.model_path}")
......@@ -613,6 +617,13 @@ def parse_args():
action="store_true",
help="Enable graph compiling",
)
parser.add_argument(
"--attn",
type=str,
default="default",
choices=["default", "flash-attn"],
help="Attention backend to use: 'default' or 'flash-attn'",
)
parser.add_argument(
"--log_level",
type=str,
......@@ -655,7 +666,7 @@ def main():
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
"\n"
"Optional: --enable-paged-attn --enable-graph"
"Optional: --enable-paged-attn --enable-graph --attn=default"
)
sys.exit(1)
......@@ -676,6 +687,7 @@ def main():
host=args.host,
port=args.port,
enable_graph=args.enable_graph,
attn_backend=args.attn,
)
server.start()
......
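On the server side the option is reachable both programmatically and through the new `--attn` CLI flag. A hedged sketch of the programmatic path, with the module path taken from the usage string above and the keyword arguments mirroring the constructor in this diff (other constructor arguments elided in the diff are omitted here):

```python
from infinilm.server.inference_server import InferenceServer  # module path from the CLI example above

# attn_backend mirrors the --attn flag: "default" or "flash-attn".
server = InferenceServer(
    model_path="/data/shared/models/9G7B_MHA/",
    host="0.0.0.0",
    port=8000,
    enable_graph=False,
    attn_backend="flash-attn",
)
server.start()
```

The CLI equivalent, per the updated help text, is the existing launch command with `--attn=flash-attn` appended.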