Commit bdd33b3f authored by zhuwenwen's avatar zhuwenwen
Browse files

update fa interface and kvcache

add prepare_so_files to prepare so
parent 63053820
...@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop) ...@@ -91,6 +91,9 @@ python3 setup.py install (若调试,可使用python3 setup.py develop)
``` ```
若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1 若需要添加git号,设置环境变量: export ADD_GIT_VERSION=1
3.跳过编译(适用于未改变csrc目录kernel并多次编译情况)
将编译后的so文件拷贝至csrc目录,并设置环境变量: export SKIP_VLLM_BUILD=1
#### 运行基础环境准备 #### 运行基础环境准备
1、使用上面基于光源pytorch2.9.0基础镜像环境 1、使用上面基于光源pytorch2.9.0基础镜像环境
......
...@@ -13,6 +13,8 @@ import sys ...@@ -13,6 +13,8 @@ import sys
import sysconfig import sysconfig
from pathlib import Path from pathlib import Path
from shutil import which from shutil import which
import tarfile
import shutil
import torch import torch
from packaging.version import Version, parse from packaging.version import Version, parse
...@@ -36,6 +38,37 @@ skip_vllm_build = False ...@@ -36,6 +38,37 @@ skip_vllm_build = False
if int(os.environ.get('SKIP_VLLM_BUILD', '0')) == 1: if int(os.environ.get('SKIP_VLLM_BUILD', '0')) == 1:
skip_vllm_build = True skip_vllm_build = True
def prepare_so_files():
source_dir = "csrc/so.tar.gz"
target_dir = "vllm"
if not os.path.exists(source_dir):
print(f"Warning: {source_dir} not found, skipping extraction")
return
print(f"Preparing C extension files from {source_dir}...")
temp_dir = "temp_so_extract"
os.makedirs(temp_dir, exist_ok=True)
try:
with tarfile.open(source_dir, "r:*") as tar:
tar.extractall(temp_dir)
for root, dirs, files in os.walk(temp_dir):
for file in files:
if file in ["_C.abi3.so", "_moe_C.abi3.so", "cumem_allocator.abi3.so"]:
src_path = os.path.join(root, file)
dst_path = os.path.join(target_dir, file)
os.makedirs(os.path.dirname(dst_path), exist_ok=True)
shutil.copy2(src_path, dst_path)
print(f"Copied {file} to {dst_path}")
finally:
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
def load_module_from_path(module_name, path): def load_module_from_path(module_name, path):
spec = importlib.util.spec_from_file_location(module_name, path) spec = importlib.util.spec_from_file_location(module_name, path)
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
...@@ -1109,6 +1142,7 @@ if _build_custom_ops(): ...@@ -1109,6 +1142,7 @@ if _build_custom_ops():
ext_modules.append(CMakeExtension(name="vllm._C")) ext_modules.append(CMakeExtension(name="vllm._C"))
if skip_vllm_build: if skip_vllm_build:
prepare_so_files()
package_data = { package_data = {
"vllm": [ "vllm": [
"py.typed", "py.typed",
......
...@@ -848,7 +848,10 @@ def unified_kv_cache_update( ...@@ -848,7 +848,10 @@ def unified_kv_cache_update(
layer_slot_mapping, layer_slot_mapping,
) )
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype) if current_platform.is_rocm():
return torch.empty(0, device=key.device, dtype=key.dtype)
else:
return torch.empty(0, device=kv_cache.device, dtype=kv_cache.dtype)
def unified_kv_cache_update_fake( def unified_kv_cache_update_fake(
......
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment