Unverified Commit 3b8e1cb7 authored by thatPepe's avatar thatPepe Committed by GitHub
Browse files

Merge pull request #260 from InfiniTensor/issue/259

issue/259 - add attn backend option to inference server
parents dfec9d89 91cd2992
...@@ -55,6 +55,7 @@ class EngineConfig: ...@@ -55,6 +55,7 @@ class EngineConfig:
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling. enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
""" """
model_path: str model_path: str
...@@ -71,6 +72,7 @@ class EngineConfig: ...@@ -71,6 +72,7 @@ class EngineConfig:
top_p: float = 0.8 top_p: float = 0.8
top_k: int = 1 top_k: int = 1
enable_graph: bool = False enable_graph: bool = False
attn_backend: str = "default"
class LLMEngine: class LLMEngine:
...@@ -88,6 +90,7 @@ class LLMEngine: ...@@ -88,6 +90,7 @@ class LLMEngine:
device=self.device, device=self.device,
distributed_config=DistConfig(config.tensor_parallel_size), distributed_config=DistConfig(config.tensor_parallel_size),
enable_graph_compiling=config.enable_graph, enable_graph_compiling=config.enable_graph,
attention_backend=config.attn_backend,
) )
# Load model weights # Load model weights
...@@ -383,6 +386,7 @@ class LLM: ...@@ -383,6 +386,7 @@ class LLM:
top_p: float = 0.8, top_p: float = 0.8,
top_k: int = 1, top_k: int = 1,
enable_graph: bool = False, enable_graph: bool = False,
attn_backend: str = "default",
): ):
"""Initialize LLM. """Initialize LLM.
...@@ -401,6 +405,7 @@ class LLM: ...@@ -401,6 +405,7 @@ class LLM:
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling. enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
""" """
config = EngineConfig( config = EngineConfig(
model_path=model_path, model_path=model_path,
...@@ -417,6 +422,7 @@ class LLM: ...@@ -417,6 +422,7 @@ class LLM:
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
enable_graph=enable_graph, enable_graph=enable_graph,
attn_backend=attn_backend,
) )
self.engine = LLMEngine(config) self.engine = LLMEngine(config)
self.config = config self.config = config
...@@ -536,6 +542,7 @@ class AsyncLLMEngine: ...@@ -536,6 +542,7 @@ class AsyncLLMEngine:
top_p: float = 0.8, top_p: float = 0.8,
top_k: int = 1, top_k: int = 1,
enable_graph: bool = False, enable_graph: bool = False,
attn_backend: str = "default",
): ):
"""Initialize AsyncLLMEngine. """Initialize AsyncLLMEngine.
...@@ -554,6 +561,7 @@ class AsyncLLMEngine: ...@@ -554,6 +561,7 @@ class AsyncLLMEngine:
top_p: Default top-p sampling parameter. top_p: Default top-p sampling parameter.
top_k: Default top-k sampling parameter. top_k: Default top-k sampling parameter.
enable_graph: Whether to enable graph compiling. enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
""" """
config = EngineConfig( config = EngineConfig(
model_path=model_path, model_path=model_path,
...@@ -570,6 +578,7 @@ class AsyncLLMEngine: ...@@ -570,6 +578,7 @@ class AsyncLLMEngine:
top_p=top_p, top_p=top_p,
top_k=top_k, top_k=top_k,
enable_graph=enable_graph, enable_graph=enable_graph,
attn_backend=attn_backend,
) )
self.engine = LLMEngine(config) self.engine = LLMEngine(config)
self.config = config self.config = config
......
...@@ -108,6 +108,7 @@ class InferenceServer: ...@@ -108,6 +108,7 @@ class InferenceServer:
host: str = "0.0.0.0", host: str = "0.0.0.0",
port: int = 8000, port: int = 8000,
enable_graph: bool = False, enable_graph: bool = False,
attn_backend: str = "default",
): ):
"""Initialize inference server. """Initialize inference server.
...@@ -128,6 +129,7 @@ class InferenceServer: ...@@ -128,6 +129,7 @@ class InferenceServer:
host: Server host address. host: Server host address.
port: Server port number. port: Server port number.
enable_graph: Whether to enable graph compiling. enable_graph: Whether to enable graph compiling.
attn_backend: Attention backend to use ('default', 'flash-attn').
""" """
self.model_path = model_path self.model_path = model_path
# vLLM-like served model id: directory name of model_path # vLLM-like served model id: directory name of model_path
...@@ -147,6 +149,7 @@ class InferenceServer: ...@@ -147,6 +149,7 @@ class InferenceServer:
self.host = host self.host = host
self.port = port self.port = port
self.enable_graph = enable_graph self.enable_graph = enable_graph
self.attn_backend = attn_backend
self.engine: AsyncLLMEngine = None self.engine: AsyncLLMEngine = None
...@@ -177,6 +180,7 @@ class InferenceServer: ...@@ -177,6 +180,7 @@ class InferenceServer:
top_p=self.top_p, top_p=self.top_p,
top_k=self.top_k, top_k=self.top_k,
enable_graph=self.enable_graph, enable_graph=self.enable_graph,
attn_backend=self.attn_backend,
) )
self.engine.start() self.engine.start()
logger.info(f"Engine initialized with model at {self.model_path}") logger.info(f"Engine initialized with model at {self.model_path}")
...@@ -613,6 +617,13 @@ def parse_args(): ...@@ -613,6 +617,13 @@ def parse_args():
action="store_true", action="store_true",
help="Enable graph compiling", help="Enable graph compiling",
) )
parser.add_argument(
"--attn",
type=str,
default="default",
choices=["default", "flash-attn"],
help="Attention backend to use: 'default' or 'flash-attn'",
)
parser.add_argument( parser.add_argument(
"--log_level", "--log_level",
type=str, type=str,
...@@ -655,7 +666,7 @@ def main(): ...@@ -655,7 +666,7 @@ def main():
"Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ " "Example: python infinilm.server.inference_server --nvidia --model_path=/data/shared/models/9G7B_MHA/ "
"--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1" "--max_tokens=100 --max_batch_size=32 --tp=1 --temperature=1.0 --top_p=0.8 --top_k=1"
"\n" "\n"
"Optional: --enable-paged-attn --enable-graph" "Optional: --enable-paged-attn --enable-graph --attn=default"
) )
sys.exit(1) sys.exit(1)
...@@ -676,6 +687,7 @@ def main(): ...@@ -676,6 +687,7 @@ def main():
host=args.host, host=args.host,
port=args.port, port=args.port,
enable_graph=args.enable_graph, enable_graph=args.enable_graph,
attn_backend=args.attn,
) )
server.start() server.start()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment