Unverified commit 77a34c28, authored by UnicornChan and committed by GitHub

Merge pull request #36 from kvcache-ai/develop-0.1.2

Release v0.1.2
Parents: 44f57270, 395cd3e7
@@ -6,6 +6,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf
 from ktransformers.models.custom_cache import StaticCache
 from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
 from ktransformers.local_chat import custom_models, default_optimize_rules
+from ktransformers.util.utils import get_device


 class KTransformersThreadContext(TransformersThreadContext):
@@ -48,8 +49,11 @@ class KTransformersInterface(TransformersInterface):
     def decode_one_tokens(self):
         if not hasattr(self, "cuda_graph_runner"):
+            device_map = self.model.gguf_loader.tensor_device_map
+            torch_device = get_device('blk.0.self_attn', device_map)
+            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
             self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, return_dict=False, use_cache=True)
+            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)

         if hasattr(self, "cuda_graph_runner"):
             logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
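For context: CUDAGraphRunner records the single-token decode step into a CUDA graph once and replays it on every later call, and the new main_device argument pins that capture to whichever GPU the tensor_device_map assigns to the first attention block ('blk.0.self_attn'). Below is a minimal standalone sketch of the generic capture-and-replay idiom in plain PyTorch; capture_decode_graph, run_decode, and static_inputs are hypothetical names for illustration, not the runner's actual API.

import torch

def capture_decode_graph(run_decode, static_inputs, device="cuda:0"):
    dev = torch.device(device)
    # Warm up on a side stream so one-time allocations happen
    # outside the captured region.
    side = torch.cuda.Stream(dev)
    side.wait_stream(torch.cuda.current_stream(dev))
    with torch.cuda.stream(side):
        run_decode(*static_inputs)
    torch.cuda.current_stream(dev).wait_stream(side)

    # Record the decode step; from here on the input/output tensor
    # addresses are frozen into the graph.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_logits = run_decode(*static_inputs)
    return graph, static_logits

# Replay: copy fresh token ids / cache positions into static_inputs
# in place, call graph.replay(), then read the refreshed static_logits.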
@@ -7,11 +7,11 @@ import pycuda.autoinit
 import pycuda.driver as cuda
 from pycuda.compiler import SourceModule
 import numpy as np
-from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
-from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
+from ktransformers.operators.linear import KTransformersLinear, KLinearMarlin
+from ktransformers.operators.experts import KTransformersExperts, KExpertsTorch
 from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
 import torch
-import CudaOps
+import KTransformersOps
 torch.set_default_dtype(torch.bfloat16)
 import time
 from transformers import (
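The import renames above (KTransformerLinear → KTransformersLinear, QuantizedLinearMarlin → KLinearMarlin, KTransformersMLPExpert → KTransformersExperts, MLPExpertsTorch → KExpertsTorch, and the CudaOps extension becoming KTransformersOps) track the operator refactor shipped in v0.1.2. Judging by its imports, this file is a GPU dequantization test; a hypothetical CPU-versus-GPU consistency check might look like the sketch below. Both signatures are assumptions for illustration (the real ones live in ktransformers/util/custom_gguf.py): dequantize_q4_k is assumed to map raw Q4_K block bytes (a numpy array) to a numpy array, and dequantize_q4_k_gpu to take the same bytes plus a target device and return a torch tensor.

import numpy as np
import torch
from ktransformers.util.custom_gguf import dequantize_q4_k, dequantize_q4_k_gpu

def q4_k_cpu_gpu_match(raw: np.ndarray, device: str = "cuda:0") -> bool:
    # CPU reference path (assumed signature: raw bytes -> np.ndarray).
    ref = torch.from_numpy(dequantize_q4_k(raw).astype(np.float32))
    # GPU path under test (assumed signature: raw bytes, device -> torch.Tensor).
    out = dequantize_q4_k_gpu(raw, device).to(torch.float32)
    # Both paths decode the same 4-bit blocks, so only float rounding
    # should separate them; compare with a tight tolerance.
    return torch.allclose(ref.to(device), out, atol=1e-3, rtol=1e-3)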