Unverified commit 169d5169 authored by RunningLeon, committed by GitHub

Add more user-friendly CLI (#541)

* add

* import fire in main

* wrap to speed up fire cli

* update

* update docs

* update docs

* fix

* resolve comments

* resolve conflict and add test for CLI
parent 7283781e
......@@ -2,7 +2,6 @@
from pathlib import Path
from typing import Union
import fire
import numpy as np
import torch
......@@ -120,5 +119,6 @@ def main(work_dir: str,
if __name__ == '__main__':
import fire
fire.Fire(main)
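Across hunks like the one above, the commit moves import fire out of module scope and into the script entry point. python-fire is only needed when a file is executed as a script, so deferring the import keeps module import, and therefore CLI startup, fast. A minimal sketch of the pattern (the parameter names are illustrative, not from this diff):

def main(work_dir: str, device: str = 'cuda'):
    # The real work lives here and has no dependency on fire,
    # so importing this module stays cheap.
    print(work_dir, device)


if __name__ == '__main__':
    # Deferred import: the cost of importing fire is only paid when
    # the file is actually run from the command line.
    import fire
    fire.Fire(main)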
......@@ -654,4 +654,5 @@ def main(model_name: str = 'test'):
if __name__ == '__main__':
import fire
fire.Fire(main)
......@@ -51,7 +51,6 @@ import itertools
import logging
from typing import Optional
import fire
import torch
from transformers import GenerationConfig, PreTrainedModel
......@@ -205,6 +204,8 @@ def main(
def cli():
import fire
fire.Fire(main)
......
# Copyright (c) OpenMMLab. All rights reserved.
import os
import fire
from lmdeploy.serve.turbomind.chatbot import Chatbot
......@@ -66,4 +64,6 @@ def main(tritonserver_addr: str,
if __name__ == '__main__':
import fire
fire.Fire(main)
......@@ -5,7 +5,6 @@ import time
from functools import partial
from typing import Sequence
import fire
import gradio as gr
from lmdeploy.serve.async_engine import AsyncEngine
......@@ -525,7 +524,7 @@ def run(model_path_or_server: str,
server_port (int): the port of gradio server
batch_size (int): batch size for running Turbomind directly
tp (int): tensor parallel for Turbomind
restufl_api (bool): a flag for model_path_or_server
restful_api (bool): a flag for model_path_or_server
"""
if ':' in model_path_or_server:
if restful_api:
......@@ -539,4 +538,6 @@ def run(model_path_or_server: str,
if __name__ == '__main__':
import fire
fire.Fire(run)
......@@ -2,7 +2,6 @@
import json
from typing import Iterable, List
import fire
import requests
......@@ -89,4 +88,6 @@ def main(restful_api_url: str, session_id: int = 0):
if __name__ == '__main__':
import fire
fire.Fire(main)
......@@ -4,7 +4,6 @@ import time
from http import HTTPStatus
from typing import AsyncGenerator, List, Optional
import fire
import uvicorn
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
......@@ -357,4 +356,6 @@ def main(model_path: str,
if __name__ == '__main__':
import fire
fire.Fire(main)
......@@ -8,7 +8,6 @@ import shutil
import sys
from pathlib import Path
import fire
import safetensors
import torch
from safetensors.torch import load_file
......@@ -1043,4 +1042,6 @@ def main(model_name: str,
if __name__ == '__main__':
import fire
fire.Fire(main)
......@@ -4,11 +4,7 @@ import os
import os.path as osp
import random
import fire
from lmdeploy import turbomind as tm
from lmdeploy.model import MODELS
from lmdeploy.tokenizer import Tokenizer
os.environ['TM_LOG_LEVEL'] = 'ERROR'
......@@ -88,6 +84,9 @@ def main(model_path,
stream_output (bool): indicator for streaming output or not
**kwargs (dict): other arguments for initializing the model's chat template
"""
from lmdeploy import turbomind as tm
from lmdeploy.tokenizer import Tokenizer
tokenizer_model_path = osp.join(model_path, 'triton_models', 'tokenizer')
tokenizer = Tokenizer(tokenizer_model_path)
tm_model = tm.TurboMind(model_path, eos_id=tokenizer.eos_token_id, tp=tp)
......@@ -157,4 +156,6 @@ def main(model_path,
if __name__ == '__main__':
import fire
fire.Fire(main)
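The hunk above applies the same deferral to the heavy turbomind and Tokenizer imports, which pull in torch; moving them into main() means invocations like --help never load the inference stack. A small timing harness (ours, not part of the commit; assumes torch is installed) shows the kind of per-invocation cost being avoided:

import importlib
import time

# Time a heavy import the way a module-level `import torch` would pay it
# on every CLI invocation, even for `--help`.
start = time.perf_counter()
importlib.import_module('torch')
print(f'import torch: {time.perf_counter() - start:.2f}s')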
......@@ -2,7 +2,6 @@
import os
import os.path as osp
import fire
import torch
from lmdeploy import turbomind as tm
......@@ -37,4 +36,6 @@ def main(model_path, inputs):
if __name__ == '__main__':
import fire
fire.Fire(main)
......@@ -2,8 +2,6 @@
import subprocess
import fire
def get_llama_gemm():
import os.path as osp
......@@ -30,4 +28,6 @@ def main(head_num: int = 32,
if __name__ == '__main__':
import fire
fire.Fire(main)
......@@ -121,26 +121,29 @@ def parse_requirements(fname='requirements.txt', with_version=True):
if __name__ == '__main__':
lmdeploy_package_data = ['lmdeploy/bin/llama_gemm']
setup(name='lmdeploy',
version=get_version(),
description='A toolset for compressing, deploying and serving LLM',
long_description=readme(),
long_description_content_type='text/markdown',
author='OpenMMLab',
author_email='openmmlab@gmail.com',
packages=find_packages(exclude=()),
package_data={
'lmdeploy': lmdeploy_package_data,
},
include_package_data=True,
install_requires=parse_requirements('requirements.txt'),
has_ext_modules=check_ext_modules,
classifiers=[
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
])
setup(
name='lmdeploy',
version=get_version(),
description='A toolset for compressing, deploying and serving LLM',
long_description=readme(),
long_description_content_type='text/markdown',
author='OpenMMLab',
author_email='openmmlab@gmail.com',
packages=find_packages(exclude=()),
package_data={
'lmdeploy': lmdeploy_package_data,
},
include_package_data=True,
install_requires=parse_requirements('requirements.txt'),
has_ext_modules=check_ext_modules,
classifiers=[
'Programming Language :: Python :: 3.8',
'Programming Language :: Python :: 3.9',
'Programming Language :: Python :: 3.10',
'Programming Language :: Python :: 3.11',
'Intended Audience :: Developers',
'Intended Audience :: Education',
'Intended Audience :: Science/Research',
],
entry_points={'console_scripts': ['lmdeploy = lmdeploy.cli:run']},
)
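The new console_scripts entry point is what makes the commit title concrete: after installation, an lmdeploy executable invokes run() in the lmdeploy.cli module. That module is not part of this diff; the following is a hypothetical sketch, inferred only from the class and method names exercised by the test file below, of how fire can turn a class into a subcommand tree:

# Hypothetical sketch of lmdeploy.cli:run; the real module is not shown
# in this diff. Class and method names mirror those imported by the tests.
import fire


class SubCliChat:
    """Subcommands under `lmdeploy chat ...`."""

    def torch(self, model_path: str):
        ...  # would delegate to lmdeploy.pytorch.chat.main

    def turbomind(self, model_path: str):
        ...  # would delegate to lmdeploy.turbomind.chat.main


class CLI:
    """Top-level command group; fire maps attributes to subcommands."""

    chat = SubCliChat()

    def convert(self, model_name: str, model_path: str):
        ...  # would delegate to lmdeploy.serve.turbomind.deploy.main


def run():
    # Exposes e.g. `lmdeploy convert ...` and `lmdeploy chat turbomind ...`.
    fire.Fire(CLI)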
import inspect


def compare_func(class_method, function):
    """Check that a class method has the same arguments as a function."""
    argspec_cls = inspect.getfullargspec(class_method)
    argspec_func = inspect.getfullargspec(function)
    assert argspec_cls.args[1:] == argspec_func.args
    assert argspec_cls.defaults == argspec_func.defaults
    assert argspec_cls.annotations == argspec_func.annotations


def test_cli():
    from lmdeploy.cli.cli import CLI
    from lmdeploy.serve.turbomind.deploy import main as convert
    compare_func(CLI.convert, convert)


def test_subcli_chat():
    from lmdeploy.cli.chat import SubCliChat
    from lmdeploy.pytorch.chat import main as run_torch_model
    from lmdeploy.turbomind.chat import main as run_turbomind_model
    compare_func(SubCliChat.torch, run_torch_model)
    compare_func(SubCliChat.turbomind, run_turbomind_model)


def test_subcli_lite():
    from lmdeploy.cli.lite import SubCliLite
    from lmdeploy.lite.apis.auto_awq import auto_awq
    from lmdeploy.lite.apis.calibrate import calibrate
    from lmdeploy.lite.apis.kv_qparams import main as run_kv_qparams
    compare_func(SubCliLite.auto_awq, auto_awq)
    compare_func(SubCliLite.calibrate, calibrate)
    compare_func(SubCliLite.kv_qparams, run_kv_qparams)


def test_subcli_serve():
    from lmdeploy.cli.serve import SubCliServe
    from lmdeploy.serve.client import main as run_triton_client
    from lmdeploy.serve.gradio.app import run as run_gradio
    from lmdeploy.serve.openai.api_client import main as run_api_client
    from lmdeploy.serve.openai.api_server import main as run_api_server
    compare_func(SubCliServe.gradio, run_gradio)
    compare_func(SubCliServe.api_server, run_api_server)
    compare_func(SubCliServe.api_client, run_api_client)
    compare_func(SubCliServe.triton_client, run_triton_client)
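These tests pin the new CLI wrappers to the functions they delegate to: if a wrapper's signature drifts from the underlying entry point, inspect.getfullargspec exposes the mismatch. A toy illustration (the function and default values are ours) of exactly what compare_func checks:

import inspect


def convert(model_name: str, model_path: str = './work'):
    ...


class CLI:
    def convert(self, model_name: str, model_path: str = './work'):
        ...


spec_func = inspect.getfullargspec(convert)
spec_cls = inspect.getfullargspec(CLI.convert)
# Drop `self` before comparing, exactly as compare_func above does.
assert spec_cls.args[1:] == spec_func.args            # ['model_name', 'model_path']
assert spec_cls.defaults == spec_func.defaults        # ('./work',)
assert spec_cls.annotations == spec_func.annotations  # both annotated str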