Commit ca7366d2 authored by Azure's avatar Azure
Browse files

Merge remote-tracking branch 'upstream/develop-0.2.2' into support-fp8

parents 581a524f cdb6f896
name: DockerHub CI
on:
release:
types: [published]
# push:
# branches:
# - main
env:
DOCKERHUB_REPO: ${{ secrets.DOCKERHUB_USERNAME }}/ktransformers
jobs:
test:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v2
- name: Run tests
run: |
if [ -f docker-compose.test.yml ]; then
docker-compose --file docker-compose.test.yml build
docker-compose --file docker-compose.test.yml run sut
else
docker build . --file Dockerfile
fi
docker_task:
needs: test
name: ${{ matrix.instruct}}
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
include:
# for amd64
- {instruct: "FANCY", platform: "linux/amd64"}
- {instruct: "AVX512", platform: "linux/amd64"}
- {instruct: "AVX2", platform: "linux/amd64"}
- {instruct: "NATIVE", platform: "linux/amd64"}
# for arm64
- {instruct: "NATIVE", platform: "linux/arm64"}
steps:
- name: Move Docker data directory
run: |
sudo systemctl stop docker
sudo mkdir -p /mnt/docker
sudo rsync -avz /var/lib/docker/ /mnt/docker
sudo rm -rf /var/lib/docker
sudo ln -s /mnt/docker /var/lib/docker
sudo systemctl start docker
-
name: Set up QEMU
uses: docker/setup-qemu-action@v3
-
name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
-
name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKERHUB_USERNAME }}
password: ${{ secrets.DOCKERHUB_TOKEN }}
-
name: Build and push for amd64
if: matrix.platform == 'linux/amd64'
uses: docker/build-push-action@v6
with:
push: true
platforms: |
linux/amd64
tags: |
${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
build-args: |
CPU_INSTRUCT=${{ matrix.instruct }}
-
name: Build and push for arm64
if: matrix.platform == 'linux/arm64'
uses: docker/build-push-action@v6
with:
push: true
platforms: |
linux/arm64
tags: |
${{ env.DOCKERHUB_REPO }}:latest-${{ matrix.instruct }}
${{ env.DOCKERHUB_REPO }}:${{ github.event.release.tag_name }}-${{ matrix.instruct }}
build-args: |
CPU_INSTRUCT=${{ matrix.instruct }}
\ No newline at end of file
...@@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt ...@@ -28,3 +28,4 @@ ktransformers/tests/chat_txt.txt
mmlu_result_q4km.json mmlu_result_q4km.json
mmlu_result_q4km.log mmlu_result_q4km.log
ktransformers/tests/mmlu_result_silicon.log ktransformers/tests/mmlu_result_silicon.log
ktransformers/ktransformers_ext/cuda_musa/
...@@ -11,6 +11,7 @@ EOF ...@@ -11,6 +11,7 @@ EOF
FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server FROM pytorch/pytorch:2.3.1-cuda12.1-cudnn8-devel as compile_server
ARG CPU_INSTRUCT=NATIVE
WORKDIR /workspace WORKDIR /workspace
ENV CUDA_HOME /usr/local/cuda ENV CUDA_HOME /usr/local/cuda
COPY --from=web_compile /home/ktransformers /workspace/ktransformers COPY --from=web_compile /home/ktransformers /workspace/ktransformers
...@@ -28,8 +29,9 @@ git submodule init && ...@@ -28,8 +29,9 @@ git submodule init &&
git submodule update && git submodule update &&
pip install ninja pyproject numpy cpufeature && pip install ninja pyproject numpy cpufeature &&
pip install flash-attn && pip install flash-attn &&
CPU_INSTRUCT=NATIVE KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose && CPU_INSTRUCT=${CPU_INSTRUCT} KTRANSFORMERS_FORCE_BUILD=TRUE TORCH_CUDA_ARCH_LIST="8.0;8.6;8.7;8.9;9.0+PTX" pip install . --no-build-isolation --verbose &&
pip cache purge pip cache purge &&
cp /usr/lib/x86_64-linux-gnu/libstdc++.so.6 /opt/conda/lib/
EOF EOF
ENTRYPOINT ["tail", "-f", "/dev/null"] ENTRYPOINT ["tail", "-f", "/dev/null"]
\ No newline at end of file
...@@ -103,7 +103,7 @@ Getting started with KTransformers is simple! Follow the steps below to set up a ...@@ -103,7 +103,7 @@ Getting started with KTransformers is simple! Follow the steps below to set up a
### 📥 Installation ### 📥 Installation
To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/). To install KTransformers, follow the official [Installation Guide](https://kvcache-ai.github.io/ktransformers/en/install.html).
<h2 id="tutorial">📃 Brief Injection Tutorial</h2> <h2 id="tutorial">📃 Brief Injection Tutorial</h2>
......
...@@ -81,7 +81,7 @@ Some preparation: ...@@ -81,7 +81,7 @@ Some preparation:
git submodule update git submodule update
``` ```
- [Optional] If you want to run with website, please [compile the website](./doc/en/api/server/website.md) before execute ```bash install.sh``` - [Optional] If you want to run with website, please [compile the website](./api/server/website.md) before execute ```bash install.sh```
- For Linux - For Linux
- For simple install: - For simple install:
...@@ -103,7 +103,7 @@ Some preparation: ...@@ -103,7 +103,7 @@ Some preparation:
install.bat install.bat
``` ```
* If you are developer, you can make use of the makefile to compile and format the code. <br> the detailed usage of makefile is [here](./doc/en/makefile_usage.md) * If you are developer, you can make use of the makefile to compile and format the code. <br> the detailed usage of makefile is [here](./makefile_usage.md)
<h3>Local Chat</h3> <h3>Local Chat</h3>
We provide a simple command-line local chat Python script that you can run for testing. We provide a simple command-line local chat Python script that you can run for testing.
......
...@@ -30,6 +30,8 @@ if (NOT MSVC) ...@@ -30,6 +30,8 @@ if (NOT MSVC)
option(LLAMA_F16C "llama: enable F16C" OFF) option(LLAMA_F16C "llama: enable F16C" OFF)
endif() endif()
option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF) option(LLAMA_AVX512_FANCY_SIMD "llama: enable AVX512-VL, AVX512-BW, AVX512-DQ, AVX512-VNNI" OFF)
option(KTRANSFORMERS_USE_CUDA "ktransformers: use CUDA" OFF)
option(KTRANSFORMERS_USE_MUSA "ktransformers: use MUSA" OFF)
# Architecture specific # Architecture specific
# TODO: probably these flags need to be tweaked on some architectures # TODO: probably these flags need to be tweaked on some architectures
...@@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party) ...@@ -208,8 +210,31 @@ include_directories(${CMAKE_CURRENT_SOURCE_DIR}/../../third_party)
if (WIN32) if (WIN32)
include_directories("$ENV{CUDA_PATH}/include") include_directories("$ENV{CUDA_PATH}/include")
elseif (UNIX) elseif (UNIX)
find_package(CUDA REQUIRED) if (KTRANSFORMERS_USE_CUDA)
include_directories("${CUDA_INCLUDE_DIRS}") find_package(CUDA REQUIRED)
include_directories("${CUDA_INCLUDE_DIRS}")
add_compile_definitions(KTRANSFORMERS_USE_CUDA=1)
endif()
if (KTRANSFORMERS_USE_MUSA)
if (NOT EXISTS $ENV{MUSA_PATH})
if (NOT EXISTS /opt/musa)
set(MUSA_PATH /usr/local/musa)
else()
set(MUSA_PATH /opt/musa)
endif()
else()
set(MUSA_PATH $ENV{MUSA_PATH})
endif()
list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
find_package(MUSAToolkit)
if (MUSAToolkit_FOUND)
message(STATUS "MUSA Toolkit found")
add_compile_definitions(KTRANSFORMERS_USE_MUSA=1)
endif()
endif()
endif() endif()
aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1) aux_source_directory(${CMAKE_CURRENT_SOURCE_DIR} SOURCE_DIR1)
...@@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama) ...@@ -225,10 +250,15 @@ target_link_libraries(${PROJECT_NAME} PRIVATE llama)
if(WIN32) if(WIN32)
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_PATH}/lib/x64/cudart.lib")#CUDA::cudart
elseif(UNIX) elseif(UNIX)
if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "") if(KTRANSFORMERS_USE_CUDA)
set(ENV{CUDA_HOME} "/usr/local/cuda") if(NOT DEFINED ENV{CUDA_HOME} OR "$ENV{CUDA_HOME}" STREQUAL "")
set(ENV{CUDA_HOME} "/usr/local/cuda")
endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif()
if(KTRANSFORMERS_USE_MUSA)
target_link_libraries(${PROJECT_NAME} PRIVATE MUSA::musart)
endif() endif()
target_link_libraries(${PROJECT_NAME} PRIVATE "$ENV{CUDA_HOME}/lib64/libcudart.so")
endif() endif()
# Define the USE_NUMA option # Define the USE_NUMA option
......
...@@ -17,7 +17,11 @@ ...@@ -17,7 +17,11 @@
#include <queue> #include <queue>
#include <thread> #include <thread>
#include <vector> #include <vector>
#include "cuda_runtime.h" #ifdef KTRANSFORMERS_USE_CUDA
#include "vendors/cuda.h"
#elif KTRANSFORMERS_USE_MUSA
#include "vendors/musa.h"
#endif
#include "backend.h" #include "backend.h"
#include "task_queue.h" #include "task_queue.h"
......
## TODO
This directory can be removed after updating the version of `llama.cpp`.
\ No newline at end of file
#pragma once
#include <cuda_runtime.h>
\ No newline at end of file
#pragma once
#include <musa_runtime.h>
#define cudaLaunchHostFunc musaLaunchHostFunc
#define cudaStream_t musaStream_t
#define cudaHostFn_t musaHostFn_t
\ No newline at end of file
/** /**
* @Description : * @Description :
* @Author : Azure-Tang * @Author : Azure-Tang, Boxin Zhang
* @Date : 2024-07-25 13:38:30 * @Date : 2024-07-25 13:38:30
* @Version : 1.0.0 * @Version : 0.2.2
* @LastEditors : kkk1nak0
* @LastEditTime : 2024-08-12 03:05:04
* @Copyright (c) 2024 by KVCache.AI, All Rights Reserved. * @Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
**/ **/
#include "custom_gguf/ops.h" #include "custom_gguf/ops.h"
#ifdef KTRANSFORMERS_USE_CUDA
#include "gptq_marlin/ops.h" #include "gptq_marlin/ops.h"
#endif
// Python bindings // Python bindings
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
...@@ -19,22 +19,46 @@ ...@@ -19,22 +19,46 @@
// namespace py = pybind11; // namespace py = pybind11;
PYBIND11_MODULE(KTransformersOps, m) { PYBIND11_MODULE(KTransformersOps, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q8_0", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.", return dequantize_q8_0((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
py::arg("data"), py::arg("blk_size"), py::arg("device")); }, "Function to dequantize q8_0 data.",
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.", m.def("dequantize_q6_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
py::arg("data"), py::arg("blk_size"), py::arg("device")); return dequantize_q6_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.", }, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device")); py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device")); m.def("dequantize_q5_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.", return dequantize_q5_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
py::arg("data"), py::arg("blk_size"), py::arg("device")); }, "Function to dequantize q5_k data.",
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.", py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"), m.def("dequantize_q4_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full")); return dequantize_q4_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q3_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
return dequantize_q3_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_q2_k", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
return dequantize_q2_k((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
m.def("dequantize_iq4_xs", [](const intptr_t data, int num_bytes, int blk_size, const int ele_per_blk, torch::Device device, torch::Dtype target_dtype) {
return dequantize_iq4_xs((int8_t*)data, num_bytes, blk_size, ele_per_blk, device, target_dtype);
}, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("num_bytes"), py::arg("blk_size"), py::arg("ele_per_blk"), py::arg("device"), py::arg("target_dtype"));
#ifdef KTRANSFORMERS_USE_CUDA
m.def("gptq_marlin_gemm", &gptq_marlin_gemm, "Function to perform GEMM using Marlin quantization.",
py::arg("a"), py::arg("b_q_weight"), py::arg("b_scales"), py::arg("g_idx"),
py::arg("perm"), py::arg("workspace"), py::arg("num_bits"), py::arg("size_m"),
py::arg("size_n"), py::arg("size_k"), py::arg("is_k_full"));
#endif
} }
#include "ops.h"
// Python bindings
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <torch/library.h>
#include <torch/extension.h>
#include <torch/torch.h>
// namespace py = pybind11;
int test(){
return 5;
}
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device);
PYBIND11_MODULE(cudaops, m) {
m.def("dequantize_q8_0", &dequantize_q8_0, "Function to dequantize q8_0 data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q6_k", &dequantize_q6_k, "Function to dequantize q6_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q5_k", &dequantize_q5_k, "Function to dequantize q5_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q4_k", &dequantize_q4_k, "Function to dequantize q4_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q3_k", &dequantize_q3_k, "Function to dequantize q3_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_q2_k", &dequantize_q2_k, "Function to dequantize q2_k data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("dequantize_iq4_xs", &dequantize_iq4_xs, "Function to dequantize iq4_xs data.",
py::arg("data"), py::arg("blk_size"), py::arg("device"));
m.def("test", &test, "Function to test.");
}
...@@ -13,10 +13,10 @@ ...@@ -13,10 +13,10 @@
#include <torch/extension.h> #include <torch/extension.h>
#include <torch/torch.h> #include <torch/torch.h>
torch::Tensor dequantize_q8_0(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q8_0(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q6_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q6_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q5_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q5_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q4_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q4_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q3_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q3_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_q2_k(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_q2_k(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
torch::Tensor dequantize_iq4_xs(torch::Tensor data, int blk_size, torch::Device device); torch::Tensor dequantize_iq4_xs(const int8_t* data, const int num_bytes, const int blk_size, const int ele_per_blk, const torch::Device device, const torch::ScalarType target_dtype);
import os
import sys
sys.path.insert(0,"/home/zbx/ktransformers")
from ktransformers.util.custom_gguf import GGUFLoader
import torch
gguf_loader_1 = GGUFLoader("/mnt/data/model/DeepseekV3-q4km-gguf")
gguf_loader_2 = GGUFLoader("/mnt/data/chenht/model/gguf_for_ktransformers/DeepSeek-V3-bf16/")
torch.set_default_dtype(torch.bfloat16)
tensor_1 = gguf_loader_1.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")
tensor_2 = gguf_loader_2.load_gguf_tensor("blk.0.attn_kv_a_mqa.weight", "cuda")
print(tensor_1[0, -64:])
print(tensor_2[0, -64:])
\ No newline at end of file
...@@ -90,7 +90,7 @@ def marlin_quantize( ...@@ -90,7 +90,7 @@ def marlin_quantize(
assert group_size <= size_k assert group_size <= size_k
# Quantize (and apply act_order if provided) # Quantize (and apply act_order if provided)
w_ref, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size, q_w, s, g_idx, rand_perm = quantize_weights(w, num_bits, group_size,
act_order) act_order)
# For act_order, sort the "weights" and "g_idx" so that group ids are # For act_order, sort the "weights" and "g_idx" so that group ids are
...@@ -107,7 +107,7 @@ def marlin_quantize( ...@@ -107,7 +107,7 @@ def marlin_quantize(
marlin_scale_perm_single[num_bits]) marlin_scale_perm_single[num_bits])
# Create result # Create result
res_list = [w_ref, marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm] res_list = [marlin_q_w, marlin_s, g_idx, sort_indices, rand_perm]
for i in range(len(res_list)): for i in range(len(res_list)):
res_list[i] = res_list[i].to(w.device) res_list[i] = res_list[i].to(w.device)
......
...@@ -11,8 +11,7 @@ def get_pack_factor(num_bits): ...@@ -11,8 +11,7 @@ def get_pack_factor(num_bits):
return 32 // num_bits return 32 // num_bits
def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): def permute_rows(q_w: torch.Tensor, group_size: int):
assert q_w.shape == w_ref.shape
orig_device = q_w.device orig_device = q_w.device
k_size, _ = q_w.shape k_size, _ = q_w.shape
...@@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int): ...@@ -26,10 +25,8 @@ def permute_rows(q_w: torch.Tensor, w_ref: torch.Tensor, group_size: int):
g_idx = g_idx[rand_perm].contiguous() g_idx = g_idx[rand_perm].contiguous()
q_w = q_w[rand_perm, :].contiguous() q_w = q_w[rand_perm, :].contiguous()
w_ref = w_ref[rand_perm, :].contiguous()
return ( return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device), q_w.to(device=orig_device),
g_idx.to(device=orig_device), g_idx.to(device=orig_device),
rand_perm.to(device=orig_device), rand_perm.to(device=orig_device),
...@@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -69,9 +66,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
q_w += half_q_val q_w += half_q_val
q_w = torch.clamp(q_w, 0, max_q_val) q_w = torch.clamp(q_w, 0, max_q_val)
# Compute ref (dequantized)
w_ref = (q_w - half_q_val).half() * s
# Restore original shapes # Restore original shapes
if group_size < size_k: if group_size < size_k:
...@@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -82,7 +76,6 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
return w return w
q_w = reshape_w(q_w) q_w = reshape_w(q_w)
w_ref = reshape_w(w_ref)
s = s.reshape((-1, size_n)).contiguous() s = s.reshape((-1, size_n)).contiguous()
...@@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int, ...@@ -95,10 +88,9 @@ def quantize_weights(w: torch.Tensor, num_bits: int, group_size: int,
), "For act_order, groupsize = {} must be less than size_k = {}".format( ), "For act_order, groupsize = {} must be less than size_k = {}".format(
group_size, size_k) group_size, size_k)
w_ref, q_w, g_idx, rand_perm = permute_rows(q_w, w_ref, group_size) q_w, g_idx, rand_perm = permute_rows(q_w, group_size)
return ( return (
w_ref.to(device=orig_device),
q_w.to(device=orig_device), q_w.to(device=orig_device),
s.to(device=orig_device), s.to(device=orig_device),
g_idx.to(device=orig_device), g_idx.to(device=orig_device),
......
...@@ -168,10 +168,7 @@ def local_chat( ...@@ -168,10 +168,7 @@ def local_chat(
if mode == 'long_context': if mode == 'long_context':
assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \ assert Config().long_context_config['max_seq_len'] > input_tensor.shape[1] + max_new_tokens, \
"please change max_seq_len in ~/.ktransformers/config.yaml" "please change max_seq_len in ~/.ktransformers/config.yaml"
torch.set_default_dtype(
torch.bfloat16
) # TODO: Remove this, replace dtype using config
if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled: if system != "Windows" and (config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM") and flashinfer_enabled:
generated = prefill_and_generate( generated = prefill_and_generate(
model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think, model, tokenizer, input_tensor.cuda(), max_new_tokens, use_cuda_graph, mode = mode, force_think = force_think,
......
...@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel): ...@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states) logits = self.lm_head(hidden_states[:,-1:,:]).float()
logits = logits[:,-1,:].unsqueeze(0).float()
loss = None loss = None
if labels is not None: if labels is not None:
......
...@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel): ...@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
) )
hidden_states = outputs[0] hidden_states = outputs[0]
logits = self.lm_head(hidden_states.to(self.lm_head.weight.device)) logits = self.lm_head(hidden_states[:,-1:,:])
logits = logits.float() logits = logits.float()
loss = None loss = None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment