Unverified commit 169d5169, authored by RunningLeon, committed by GitHub

Add more user-friendly CLI (#541)

* add

* import fire in main

* wrap to speed up fire cli

* update

* update docs

* update docs

* fix

* resolve comments

* resolve conflict and add test for cli
parent 7283781e
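
The pattern repeated throughout the diff below is moving `import fire` out of module scope and into the `if __name__ == '__main__':` block (or a dedicated `cli()` wrapper), so that importing lmdeploy modules as a library no longer pays fire's import cost. A minimal sketch of the pattern, using an illustrative module that is not part of this PR:

# chat_example.py -- hypothetical module showing the lazy-import pattern.
def main(model_path: str, tp: int = 1):
    """Run a chat session; heavy dependencies can also be imported here, lazily."""
    print(f'loading {model_path} with tp={tp}')


if __name__ == '__main__':
    # fire is only imported when the module runs as a script, so
    # `import chat_example` from library code stays fast.
    import fire
    fire.Fire(main)
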
@@ -2,7 +2,6 @@
 from pathlib import Path
 from typing import Union

-import fire
 import numpy as np
 import torch
@@ -120,5 +119,6 @@ def main(work_dir: str,

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -654,4 +654,5 @@ def main(model_name: str = 'test'):

 if __name__ == '__main__':
     import fire
     fire.Fire(main)
@@ -51,7 +51,6 @@ import itertools
 import logging
 from typing import Optional

-import fire
 import torch
 from transformers import GenerationConfig, PreTrainedModel
@@ -205,6 +204,8 @@ def main(

 def cli():
+    import fire
     fire.Fire(main)
 # Copyright (c) OpenMMLab. All rights reserved.
 import os

-import fire
 from lmdeploy.serve.turbomind.chatbot import Chatbot
@@ -66,4 +64,6 @@ def main(tritonserver_addr: str,

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -5,7 +5,6 @@ import time
 from functools import partial
 from typing import Sequence

-import fire
 import gradio as gr

 from lmdeploy.serve.async_engine import AsyncEngine
@@ -525,7 +524,7 @@ def run(model_path_or_server: str,
         server_port (int): the port of gradio server
         batch_size (int): batch size for running Turbomind directly
         tp (int): tensor parallel for Turbomind
-        restufl_api (bool): a flag for model_path_or_server
+        restful_api (bool): a flag for model_path_or_server
     """
     if ':' in model_path_or_server:
         if restful_api:
@@ -539,4 +538,6 @@ def run(model_path_or_server: str,

 if __name__ == '__main__':
+    import fire
     fire.Fire(run)
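
For context on the docstring fix above: `run` treats an argument containing ':' as a server address and uses `restful_api` to decide whether that address belongs to a RESTful API server or a Triton inference server; anything else is treated as a local model path. Hypothetical calls, with illustrative paths and addresses:

from lmdeploy.serve.gradio.app import run

run('./workspace')                      # local Turbomind model directory
run('0.0.0.0:33337')                    # Triton inference server address
run('0.0.0.0:23333', restful_api=True)  # RESTful API server address
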
@@ -2,7 +2,6 @@
 import json
 from typing import Iterable, List

-import fire
 import requests
@@ -89,4 +88,6 @@ def main(restful_api_url: str, session_id: int = 0):

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -4,7 +4,6 @@ import time
 from http import HTTPStatus
 from typing import AsyncGenerator, List, Optional

-import fire
 import uvicorn
 from fastapi import FastAPI, Request
 from fastapi.middleware.cors import CORSMiddleware
@@ -357,4 +356,6 @@ def main(model_path: str,

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -8,7 +8,6 @@ import shutil
 import sys
 from pathlib import Path

-import fire
 import safetensors
 import torch
 from safetensors.torch import load_file
@@ -1043,4 +1042,6 @@ def main(model_name: str,

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -4,11 +4,7 @@ import os
 import os.path as osp
 import random

-import fire
-
-from lmdeploy import turbomind as tm
 from lmdeploy.model import MODELS
-from lmdeploy.tokenizer import Tokenizer

 os.environ['TM_LOG_LEVEL'] = 'ERROR'
@@ -88,6 +84,9 @@ def main(model_path,
         stream_output (bool): indicator for streaming output or not
         **kwarg (dict): other arguments for initializing model's chat template
     """
+    from lmdeploy import turbomind as tm
+    from lmdeploy.tokenizer import Tokenizer
+
     tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
     tokenizer = Tokenizer(tokenizer_model_path)
     tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
@@ -157,4 +156,6 @@ def main(model_path,

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -2,7 +2,6 @@
 import os
 import os.path as osp

-import fire
 import torch

 from lmdeploy import turbomind as tm
@@ -37,4 +36,6 @@ def main(model_path, inputs):

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -2,8 +2,6 @@
 import subprocess

-import fire
-

 def get_llama_gemm():
     import os.path as osp
@@ -30,4 +28,6 @@ def main(head_num: int = 32,

 if __name__ == '__main__':
+    import fire
     fire.Fire(main)
@@ -121,26 +121,29 @@ def parse_requirements(fname='requirements.txt', with_version=True):

 if __name__ == '__main__':
     lmdeploy_package_data = ['lmdeploy/bin/llama_gemm']
-    setup(name='lmdeploy',
-          version=get_version(),
-          description='A toolset for compressing, deploying and serving LLM',
-          long_description=readme(),
-          long_description_content_type='text/markdown',
-          author='OpenMMLab',
-          author_email='openmmlab@gmail.com',
-          packages=find_packages(exclude=()),
-          package_data={
-              'lmdeploy': lmdeploy_package_data,
-          },
-          include_package_data=True,
-          install_requires=parse_requirements('requirements.txt'),
-          has_ext_modules=check_ext_modules,
-          classifiers=[
-              'Programming Language :: Python :: 3.8',
-              'Programming Language :: Python :: 3.9',
-              'Programming Language :: Python :: 3.10',
-              'Programming Language :: Python :: 3.11',
-              'Intended Audience :: Developers',
-              'Intended Audience :: Education',
-              'Intended Audience :: Science/Research',
-          ])
+    setup(
+        name='lmdeploy',
+        version=get_version(),
+        description='A toolset for compressing, deploying and serving LLM',
+        long_description=readme(),
+        long_description_content_type='text/markdown',
+        author='OpenMMLab',
+        author_email='openmmlab@gmail.com',
+        packages=find_packages(exclude=()),
+        package_data={
+            'lmdeploy': lmdeploy_package_data,
+        },
+        include_package_data=True,
+        install_requires=parse_requirements('requirements.txt'),
+        has_ext_modules=check_ext_modules,
+        classifiers=[
+            'Programming Language :: Python :: 3.8',
+            'Programming Language :: Python :: 3.9',
+            'Programming Language :: Python :: 3.10',
+            'Programming Language :: Python :: 3.11',
+            'Intended Audience :: Developers',
+            'Intended Audience :: Education',
+            'Intended Audience :: Science/Research',
+        ],
+        entry_points={'console_scripts': ['lmdeploy = lmdeploy.cli:run']},
+    )
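
The new `entry_points` line is what turns the package into a shell command: after installation, running `lmdeploy <args>` resolves to the `run` callable in the `lmdeploy.cli` package. Based on the classes exercised by the new tests below, a fire-based dispatcher of roughly this shape would yield `lmdeploy convert ...` and `lmdeploy chat torch|turbomind ...` subcommands; this is a sketch of the pattern, not the actual implementation:

import fire


class SubCliChat:
    """Hypothetical subcommand group: `lmdeploy chat torch|turbomind ...`."""

    def torch(self, model_path: str):
        ...

    def turbomind(self, model_path: str):
        ...


class CLI:
    """Hypothetical root command; attributes become subcommand groups."""

    def __init__(self):
        self.chat = SubCliChat()

    def convert(self, model_name: str, model_path: str):
        ...


def run():
    # fire maps methods and instance attributes to subcommands and
    # turns their parameters into positional args and --flags.
    fire.Fire(CLI())
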
import inspect


def compare_func(class_method, function):
    """Check that a class method exposes the same arguments as a function."""
    argspec_cls = inspect.getfullargspec(class_method)
    argspec_func = inspect.getfullargspec(function)
    assert argspec_cls.args[1:] == argspec_func.args
    assert argspec_cls.defaults == argspec_func.defaults
    assert argspec_cls.annotations == argspec_func.annotations


def test_cli():
    from lmdeploy.cli.cli import CLI
    from lmdeploy.serve.turbomind.deploy import main as convert

    compare_func(CLI.convert, convert)


def test_subcli_chat():
    from lmdeploy.cli.chat import SubCliChat
    from lmdeploy.pytorch.chat import main as run_torch_model
    from lmdeploy.turbomind.chat import main as run_turbomind_model

    compare_func(SubCliChat.torch, run_torch_model)
    compare_func(SubCliChat.turbomind, run_turbomind_model)


def test_subcli_lite():
    from lmdeploy.cli.lite import SubCliLite
    from lmdeploy.lite.apis.auto_awq import auto_awq
    from lmdeploy.lite.apis.calibrate import calibrate
    from lmdeploy.lite.apis.kv_qparams import main as run_kv_qparams

    compare_func(SubCliLite.auto_awq, auto_awq)
    compare_func(SubCliLite.calibrate, calibrate)
    compare_func(SubCliLite.kv_qparams, run_kv_qparams)


def test_subcli_serve():
    from lmdeploy.cli.serve import SubCliServe
    from lmdeploy.serve.client import main as run_triton_client
    from lmdeploy.serve.gradio.app import run as run_gradio
    from lmdeploy.serve.openai.api_client import main as run_api_client
    from lmdeploy.serve.openai.api_server import main as run_api_server

    compare_func(SubCliServe.gradio, run_gradio)
    compare_func(SubCliServe.api_server, run_api_server)
    compare_func(SubCliServe.api_client, run_api_client)
    compare_func(SubCliServe.triton_client, run_triton_client)
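
These tests pin each CLI method's signature to the underlying function it wraps. `inspect.getfullargspec` returns a function's positional argument names, defaults, and annotations, and slicing `args[1:]` drops `self` from the method's spec. A small self-contained illustration of the technique, with toy names that are not part of the test suite:

import inspect


def greet(name: str, loud: bool = False):
    ...


class Wrapper:
    def greet(self, name: str, loud: bool = False):
        ...


cls_spec = inspect.getfullargspec(Wrapper.greet)
fn_spec = inspect.getfullargspec(greet)
assert cls_spec.args[1:] == fn_spec.args      # ['name', 'loud']
assert cls_spec.defaults == fn_spec.defaults  # (False,)
assert cls_spec.annotations == fn_spec.annotations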