Unverified Commit 22e8b2ca authored by Chen Xin's avatar Chen Xin Committed by GitHub
Browse files

Package 'bin/llama_gemm' to wheel (#320)

* pack llama_gemm

* update CMakeLists.txt

* remove candidate

* update MANIFEST.in
parent eaccbc0a
......@@ -44,6 +44,7 @@ htmlcov/
*build*/
!builder/
lmdeploy/lib/
lmdeploy/bin/
dist/
examples/cpp/llama/*.csv
*.npy
......
......@@ -3,5 +3,6 @@ include lmdeploy/lib/*.so
include lmdeploy/lib/*.so*
include lmdeploy/lib/*.dll
include lmdeploy/lib/*.pyd
include lmdeploy/bin/*
include lmdeploy/serve/turbomind/service_docker_up.sh
recursive-include lmdeploy/serve/turbomind/triton_models *
......@@ -5,6 +5,16 @@ import subprocess
import fire
def get_llama_gemm():
    """Locate the `llama_gemm` executable shipped inside the installed
    lmdeploy wheel (packaged under `lmdeploy/bin/` by the build).

    Returns:
        str: absolute path to the `llama_gemm` binary.

    Raises:
        AssertionError: if the binary is not present in the package.
    """
    import os.path as osp

    import lmdeploy

    # lmdeploy.__file__ points at the package's __init__.py; its directory
    # is the installed package root that also contains the `bin/` folder.
    lmdeploy_dir = osp.dirname(lmdeploy.__file__)
    bin_path = osp.join(lmdeploy_dir, 'bin', 'llama_gemm')
    assert osp.exists(bin_path), f'{bin_path} does not exist'
    return bin_path
def main(head_num: int = 32,
size_per_head: int = 128,
vocab_size: int = 32000,
......@@ -13,8 +23,9 @@ def main(head_num: int = 32,
max_batch_size: int = 64):
for bsz in range(1, max_batch_size + 1):
subprocess.call(
f'bin/llama_gemm {bsz} 1 1 {head_num} {size_per_head} {inter_size}'
f' {vocab_size} 1 {tensor_para_size} {0 if bsz == 1 else 1}',
f'{get_llama_gemm()} {bsz} 1 1 {head_num} {size_per_head}'
f' {inter_size} {vocab_size} 1 {tensor_para_size}'
f' {0 if bsz == 1 else 1}',
shell=True)
......
......@@ -120,6 +120,7 @@ def parse_requirements(fname='requirements.txt', with_version=True):
if __name__ == '__main__':
lmdeploy_package_data = ['lmdeploy/bin/llama_gemm']
setup(name='lmdeploy',
version=get_version(),
description='A toolset for compressing, deploying and serving LLM',
......@@ -128,6 +129,9 @@ if __name__ == '__main__':
author='OpenMMLab',
author_email='openmmlab@gmail.com',
packages=find_packages(exclude=()),
package_data={
'lmdeploy': lmdeploy_package_data,
},
include_package_data=True,
install_requires=parse_requirements('requirements.txt'),
has_ext_modules=check_ext_modules,
......
......@@ -48,3 +48,4 @@ endif()
# Build the gemm-tuning helper binary used by the deploy-time tuner.
add_executable(llama_gemm llama_gemm.cc)
target_link_libraries(llama_gemm PUBLIC CUDA::cudart gpt_gemm_func memory_utils cuda_utils logger)
# Install the binary into the source tree (lmdeploy/bin/) so it is picked up
# by setup.py's package_data and MANIFEST.in when building the wheel.
# NOTE(review): ${CMAKE_SOURCE_DIR} is the top of the whole build; if this
# project is ever consumed as a subproject, ${PROJECT_SOURCE_DIR} would be
# the safer choice — confirm the intended top-level-only usage.
install(TARGETS llama_gemm DESTINATION ${CMAKE_SOURCE_DIR}/lmdeploy/bin)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment