Commit 18c42e67 authored by chenxl: Initial commit
from enum import Enum
from typing import Dict, List, Optional, Union, ForwardRef
from pydantic import BaseModel, Field, model_validator
from ktransformers.server.models.assistants.runs import Run
from ktransformers.server.schemas.base import TODO, Metadata, MetadataField, ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.threads import ThreadCreate
from ktransformers.server.schemas.assistants.tool import Tool, ToolResource
from ktransformers.server.utils.sql_utils import SQLUtil
class ToolCall(BaseModel):
id: str
type: str
function: TODO
class SubmitToolOutputs(BaseModel):
tool_calls: List[ToolCall]
class RequiredAction(BaseModel):
type: str
submit_tool_outputs: TODO
class LastError(BaseModel):
code: str
message: str
class IncompleteDetails(BaseModel):
reason: str
class Usage(BaseModel):
completion_tokens: int
prompt_tokens: int
total_tokens: int
class TruncationStrategy(BaseModel):
type: str = "auto"
last_message: Optional[int]
class ToolChoiceType(Enum):
none = "none"
auto = "auto"
required = "required"
class RunBase(BaseModel):
class Status(Enum):
created = "created" # only stream event will have this created status
queued = "queued"
in_progress = "in_progress"
requires_action = "requires_action"
cancelling = "cancelling"
cancelled = "cancelled"
failed = "failed"
completed = "completed"
expired = "expired"
thread_id: str
assistant_id: str
status: Status = Status.queued
required_action: Optional[RequiredAction] = Field(None)
last_error: Optional[LastError] = Field(None)
expires_at: Optional[int]= Field(None)
started_at: Optional[int] = Field(None)
cancelled_at: Optional[int] = Field(None)
failed_at: Optional[int] = Field(None)
completed_at: Optional[int] = Field(None)
incomplete_details: Optional[IncompleteDetails] = Field(None)
model: Optional[str] = Field(None)
instructions: Optional[str] = Field(None)
tools: Optional[List[Tool]] = Field([])
meta_data: Metadata = MetadataField
@model_validator(mode='before')
@classmethod
def convert_meta_data(cls,values):
if 'meta_data' in values:
values['metadata'] = values['meta_data']
return values
def set_compute_save(self,save:int):
self.meta_data['compute_save'] = str(save)
usage: Optional[Usage] = Field(None)
temperature: Optional[float] = Field(None)
top_p: Optional[float]= Field(None)
max_prompt_tokens: Optional[int] = Field(None)
truncation_strategy: Optional[TruncationStrategy]= Field(None)
tool_choice: Optional[Union[ToolChoiceType, dict]]= Field(None)
response_format: Union[str, Dict[str, str]] = "auto"
RunStreamResponse = ForwardRef('RunStreamResponse')
class RunObject(RunBase, ObjectWithCreatedTime):
def stream_response_with_event(self,event:RunBase.Status)->RunStreamResponse:
match event:
case RunBase.Status.created:
self.status = RunBase.Status.queued
case _:
self.status = event
return RunStreamResponse(run=self, event=event)
def sync_db(self):
# raise NotImplementedError # should be replaced in crud
sql_utils = SQLUtil()
db_run = Run(
**self.model_dump(mode='json'),
)
with sql_utils.get_db() as db:
sql_utils.db_merge_commit(db, db_run)
def create_message_creation_step(self):
raise NotImplementedError # should be replaced
class RunStreamResponse(BaseModel):
run: RunObject
event: RunObject.Status
def to_stream_reply(self):
return f"event: thread.run.{self.event.value}\ndata: {self.run.model_dump_json()}\n\n"
class RunCreate(BaseModel):
assistant_id: str
model: Optional[str] = Field(default=None)
instructions: Optional[str] = Field(default=None)
# TODO: Add this
# additional_instructions: Optional[str]
# additional_messages: Optional[List[MessageCore]]
tools: List[Tool] = Field(default=[])
meta_data: Metadata = MetadataField
@model_validator(mode='before')
@classmethod
def convert_meta_data(cls,values):
if 'meta_data' in values:
values['metadata'] = values['meta_data']
return values
temperature: Optional[float] = Field(default=None)
top_p: Optional[float] = Field(default=None)
stream: Optional[bool] = Field(default=None)
max_prompt_tokens: Optional[int] = Field(default=None)
# TODO: Add this
# max_completion_tokens: Optional[int]
truncation_strategy: Optional[TruncationStrategy] = Field(default=None)
tool_choice: Optional[Union[ToolChoiceType, dict]] = Field(default=None)
response_format: Union[str, Dict[str, str]] = Field(default="auto")
class RunThreadCreate(BaseModel):
assistant_id: str
thread: Optional[ThreadCreate]
model: Optional[str]
instructions: Optional[str]
tools: List[Tool]
tool_resources: List[ToolResource]
meta_data: Metadata = MetadataField
@model_validator(mode='before')
@classmethod
def convert_meta_data(cls,values):
if 'meta_data' in values:
values['metadata'] = values['meta_data']
return values
temperature: Optional[float]
top_p: Optional[float]
stream: Optional[bool]
max_prompt_tokens: Optional[int]
# TODO: Add this
# max_completion_tokens: Optional[int]
truncation_strategy: TruncationStrategy
tool_choice: Union[ToolChoiceType, dict]
response_format: Union[str, Dict[str, str]] = "auto"
class RunModify(BaseModel):
meta_data: Metadata = MetadataField
@model_validator(mode='before')
@classmethod
def convert_meta_data(cls,values):
if 'meta_data' in values:
values['metadata'] = values['meta_data']
return values
class ToolOutput(BaseModel):
tool_call_id: Optional[str]
output: Optional[str]
class RunSubmit(BaseModel):
tool_outputs: List[ToolOutput]
stream: Optional[bool]
import asyncio
from typing import AsyncIterable, List, Union
from fastapi import Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from ktransformers.server.schemas.assistants.runs import RunStreamResponse
from ktransformers.server.schemas.endpoints.chat import ChatCompletionChunk
from ktransformers.server.config.log import logger
from ktransformers.server.schemas.base import Object
from ktransformers.server.schemas.assistants.messages import ContentType, ImageFileObject, ImageUrlObject, MessageObject, Text, TextObject
class TextObjectWithIndex(TextObject):
index: int
class ImageFileObjectWithIndex(ImageFileObject):
index: int
class ImageUrlObjectWithIndex(ImageUrlObject):
index: int
ContentWithIndex = Union[TextObjectWithIndex,
ImageFileObjectWithIndex, ImageUrlObjectWithIndex]
class MessageDeltaImpl(BaseModel):
# role: Optional[str]
content: List[ContentWithIndex]
class MessageDelta(Object):
delta: MessageDeltaImpl
def to_stream_reply(self):
return f"event: thread.message.delta\ndata: {self.model_dump_json()}\n\n"
def text_delta(index: int, text: str):
return MessageDeltaImpl(content=[TextObjectWithIndex(index=index, type=ContentType.text, text=Text(value=text))])
def append_message_delta(self: MessageObject, text: str):
if len(self.content) == 0:
self.content.append(TextObject(type=ContentType.text,
text=Text(value=''), delta_index=0))
text_object: TextObject = self.content[0]
if text_object.filter_append(text):
return MessageDelta(id=self.id, object="thread.message.delta", delta=text_delta(text_object.delta_index, text))
else:
return None
MessageObject.append_message_delta = append_message_delta
class RunStepDeltaImpl(BaseModel):
pass
class RunStepDelta(Object):
delta: RunStepDeltaImpl
def to_stream_reply(self):
return f"event: thread.run.step.delta\ndata: {self.model_dump_json()}\n\n"
class Done():
def to_stream_reply(self):
return f"event: done\ndata: [DONE]\n\n"
async def check_client_link(request: Request, async_events: AsyncIterable):
async for event in async_events:
if await request.is_disconnected():
break
yield event
async def add_done(async_events: AsyncIterable):
async for event in async_events:
yield event
yield Done()
async def to_stream_reply(async_events: AsyncIterable):
async for event in async_events:
if isinstance(event, str):
yield event
else:
yield event.to_stream_reply()
async def filter_api_event(async_events: AsyncIterable):
async for event in async_events:
if isinstance(event, (MessageDelta, RunStepDelta, RunStreamResponse, Done)):
    yield event
async def filter_chat_chunk(async_events: AsyncIterable):
async for event in async_events:
if isinstance(event, ChatCompletionChunk):
yield event
async def filter_by_types(async_events: AsyncIterable, types: List):
    async for event in async_events:
        for event_type in types:
            if isinstance(event, event_type):
                yield event
                break  # yield each event at most once, even if it matches several types
def api_stream_response(request: Request, async_events: AsyncIterable):
return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_api_event(async_events)))), media_type="text/event-stream")
def chat_stream_response(request: Request, async_events: AsyncIterable):
return StreamingResponse(check_client_link(request, to_stream_reply(add_done(filter_chat_chunk(async_events)))), media_type="text/event-stream")
def stream_response(request: Request, async_events: AsyncIterable):
return StreamingResponse(check_client_link(request, to_stream_reply(add_done(async_events))), media_type="text/event-stream")
def check_link_response(request: Request, async_events: AsyncIterable):
return StreamingResponse(check_client_link(request, async_events), media_type="text/event-stream")
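# Illustrative sketch (not part of the original module): how the wrappers above are
# typically combined. An endpoint produces a raw async generator of events and hands it
# to one of the *_stream_response helpers, which (innermost to outermost) run
# filter_api_event -> add_done -> to_stream_reply -> check_client_link, so events are
# filtered, the [DONE] sentinel is appended, everything is framed as SSE text, and the
# stream stops early if the client disconnects. The generator below is hypothetical and
# exists only for the example.
async def _example_event_source() -> AsyncIterable:
    # a real interface would yield MessageDelta / RunStreamResponse objects here
    yield Done()

def _example_api_endpoint(request: Request) -> StreamingResponse:
    return api_stream_response(request, _example_event_source())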
def wrap_async_generator_into_queue(async_events: AsyncIterable) -> asyncio.Queue:
queue = asyncio.Queue()
async def inner():
# logger.debug('run inner')
async for event in async_events:
# logger.debug(f'put: {event}')
await queue.put(event)
await asyncio.sleep(0)
# logger.debug(f'put: None')
await queue.put(None)
asyncio.create_task(inner())
return queue
async def unwrap_async_queue(queue: asyncio.Queue) -> AsyncIterable:
while True:
events = [await queue.get()]
events.extend([queue.get_nowait() for _ in range(queue.qsize())])
logger.debug(f'getting {len(events)} events')
for event in events:
if event is None:
return  # sentinel from the producer: terminate the generator, not just the inner loop
yield event
async def unwrap_async_queue_slow(queue: asyncio.Queue) -> AsyncIterable:
while True:
event = await queue.get()
# logger.debug(f'unwrap_async_queue {event}')
if event is None:
break
yield event
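# Illustrative sketch (not part of the original module): wrap_async_generator_into_queue and
# unwrap_async_queue form a producer/consumer pair. Wrapping starts a background task that
# drains the generator into a queue (None is the end sentinel), and unwrapping turns the
# queue back into an async iterator, draining whatever is already buffered in one go. The
# inner generator below is hypothetical.
async def _example_fan_through_queue():
    async def produce():
        for i in range(3):
            yield i

    queue = wrap_async_generator_into_queue(produce())
    collected = []
    async for item in unwrap_async_queue(queue):
        collected.append(item)
    return collected  # [0, 1, 2]
# Usage: asyncio.run(_example_fan_through_queue())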
from enum import Enum
from typing import List
from typing_extensions import Self
from pydantic import BaseModel, Field, model_validator
from ktransformers.server.schemas.base import Metadata, MetadataField, ObjectWithCreatedTime
from ktransformers.server.schemas.assistants.tool import ToolResource
from ktransformers.server.schemas.assistants.messages import MessageCore
class ThreadBase(BaseModel):
meta_data: Metadata = MetadataField
@model_validator(mode='before')
@classmethod
def convert_meta_data(cls,values):
if 'meta_data' in values:
values['metadata'] = values['meta_data']
return values
tool_resources: List[ToolResource] = Field([], max_length=128)
class ThreadObject(ThreadBase, ObjectWithCreatedTime):
is_related_threads:bool = Field(False,exclude=True)
@model_validator(mode='after')
def check_is_related_threads(self)->Self:
# logger.debug(f'check thread {self.id} is related thread? by {self}')
if 'assistant_id' in self.meta_data:
self.is_related_threads = True
return self
class StreamEvent(Enum):
created = 'created'
def to_stream_reply(self,event:StreamEvent):
return f"event: thread.{event.value}\ndata: {self.model_dump_json()}\n\n"
class ThreadCreate(ThreadBase):
messages: List[MessageCore] = Field(default=[])
class ThreadModify(ThreadBase):
pass
# other than OpenAI API
from enum import Enum
from typing import List, Optional, Union
from pydantic import BaseModel, Field
from ktransformers.server.schemas.base import ObjectID
class ToolType(str, Enum):
CODE_INTERPRETER = "code_interpreter"
FILE_SEARCH = "file_search"
RELATED_THREADS = "related_threads"
FUNCTION = "function"
class ToolBase(BaseModel):
type: ToolType
class CodeInterpreter(ToolBase):
pass
class FileSearch(ToolBase):
pass
class RelatedThreads(ToolBase):
pass
class FunctionTool(ToolBase):
description: str
name: str
parameters: List[str]
Tool = Union[CodeInterpreter, FileSearch, RelatedThreads, FunctionTool]
class CodeInterpreterResource(BaseModel):
file_ids: Optional[List[str]] = Field(default_factory=list, max_length=20)
class FileSearchResource(BaseModel):
vector_store_ids: Optional[List[str]] = Field(default_factory=list, max_length=1)
vector_stores: Optional[List[str]] = Field(default_factory=list, max_length=1)
class RelatedThreadsResource(BaseModel):
thread_ids: List[ObjectID] = Field(default=[])
ToolResource = Union[CodeInterpreterResource,FileSearchResource,RelatedThreadsResource]
from enum import Enum
from typing import Dict
import sqlalchemy
from pydantic import BaseModel, ConfigDict, Field
TODO = BaseModel
ObjectID = str
class Object(BaseModel):
id: ObjectID
object: str
model_config = ConfigDict(from_attributes=True)
# Pydantic Base Models
class ObjectWithCreatedTime(Object):
created_at: int
class Order(str, Enum):
ASC = "asc"
DESC = "desc"
def to_sqlalchemy_order(self):
match self:
case Order.ASC:
return sqlalchemy.asc
case Order.DESC:
return sqlalchemy.desc
Metadata = Dict[str, str]
MetadataField: Metadata = Field({},max_length=16, alias="metadata")
class DeleteResponse(Object):
deleted: bool = True
class OperationResponse(BaseModel):
operation: str
status: str
from typing import Optional
from pydantic import BaseModel
from .assistants.assistants import AssistantObject
from .assistants.threads import ThreadObject
from .assistants.messages import MessageObject
class ThreadPreview(BaseModel):
assistant: Optional[AssistantObject] = None
thread: ThreadObject
first_message: Optional[MessageObject] = None
from typing import List, Optional
from enum import Enum
from pydantic import BaseModel
from ktransformers.server.schemas.base import Object
class Role(Enum):
system = 'system'
user = 'user'
assistant = 'assistant'
tool = 'tool'
function = 'function'
class Message(BaseModel):
content: str
role:Role
name: Optional[str] = None
def to_tokenizer_message(self):
return {'content':self.content,'role':self.role.value}
class ChatCompletionCreate(BaseModel):
messages: List[Message]
model : str
stream : bool = False
def get_tokenizer_messages(self):
return [m.to_tokenizer_message() for m in self.messages]
class FinishReason(Enum):
stop = 'stop'
length = 'length'
class Choice(BaseModel):
index: int
message: Message
logprobs: Optional[str] = None
finish_reason: Optional[FinishReason] = None
class DeltaChoice(BaseModel):
index: int
delta: Message
logprobs: Optional[str] = None
finish_reason: Optional[FinishReason] = None
class Usage(BaseModel):
completion_tokens:int
prompt_tokens:int
total_tokens:int
class ChatCompletionBase(Object):
created:int
model:str = 'not implemented'
system_fingerprint:str = 'not implemented'
usage: Optional[Usage] = None
class ChatCompletionObject(ChatCompletionBase):
choices:List[Choice] = []
def append_token(self,token:str):
if len(self.choices) == 0:
self.choices.append(Choice(index=0,message=Message(content='',role=Role.assistant)))
self.choices[0].message.content += token
class ChatCompletionChunk(ChatCompletionBase):
choices:List[DeltaChoice] = []
def set_token(self,token:str):
self.choices = [
DeltaChoice(index=0,delta=Message(content=token,role=Role.assistant))
]
def to_stream_reply(self):
return f"data:{self.model_dump_json()}\n\n"
from typing import List, Optional
from enum import Enum
from pydantic import BaseModel
from ..base import Object
class CompletionCreate(BaseModel):
model: str
prompt: str | List[str]
stream: bool = False
def get_tokenizer_messages(self):
if isinstance(self.prompt, list):
    return [{'content': '\n'.join(self.prompt), 'role': 'user'}]
return [{'content': self.prompt, 'role': 'user'}]
class FinishReason(Enum):
stop = 'stop'
length = 'length'
class Choice(BaseModel):
index: int
text: str
logprobs: Optional[str] = None
finish_reason: Optional[FinishReason] = None
class CompletionObject(Object):
created:int
choices: List[Choice] = []
model:str = 'not implemented'
system_fingerprint:str = 'not implemented'
usage: Optional[str] = None
def set_token(self,token:str):
if len(self.choices)==0:
self.choices.append(Choice(index=0,text=''))
self.choices[0].text = token
def append_token(self,token:str):
if len(self.choices)==0:
self.choices.append(Choice(index=0,text=''))
self.choices[0].text += token
def to_stream_reply(self):
return f"data:{self.model_dump_json()}\n\n"
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : qiyuxinlin
Date : 2024-07-25 11:50:16
Version : 1.0.0
LastEditors : qiyuxinlin
LastEditTime : 2024-07-25 12:54:48
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
from ktransformers.server.config.config import Config
from ktransformers.server.backend.args import ConfigArgs
from ktransformers.server.backend.context_manager import ThreadContextManager
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface
from ktransformers.server.backend.interfaces.transformers import TransformersInterface
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface
def create_interface(config: Config, default_args: ConfigArgs):
if config.backend_type=='transformers':
from ktransformers.server.backend.interfaces.transformers import TransformersInterface as BackendInterface
elif config.backend_type == 'exllamav2':
from ktransformers.server.backend.interfaces.exllamav2 import ExllamaInterface as BackendInterface
elif config.backend_type == 'ktransformers':
from ktransformers.server.backend.interfaces.ktransformers import KTransformersInterface as BackendInterface
else:
raise NotImplementedError(f'{config.backend_type} not implemented')
GlobalInterface.interface = BackendInterface(default_args)
GlobalContextManager.context_manager = ThreadContextManager(GlobalInterface.interface)
class GlobalContextManager:
context_manager: ThreadContextManager
class GlobalInterface:
interface: TransformersInterface | KTransformersInterface | ExllamaInterface
def get_thread_context_manager() -> ThreadContextManager:
return GlobalContextManager.context_manager
def get_interface() -> TransformersInterface | KTransformersInterface | ExllamaInterface:
return GlobalInterface.interface
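# Illustrative sketch (not part of the original module): typical startup wiring. A Config
# instance selects the backend, create_interface builds it once, and the rest of the server
# reaches it through the two accessors above. The call assumes Config() and ConfigArgs()
# can be constructed with their defaults, which may not hold in every deployment.
def _example_bootstrap():
    cfg = Config()
    create_interface(config=cfg, default_args=ConfigArgs())
    return get_interface(), get_thread_context_manager()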
import time
def format_time(seconds):
units = [
("hours", 3600),
("minutes", 60),
("seconds", 1),
("milliseconds", 1e-3),
("microseconds", 1e-6),
]
for unit_name, unit_value in units:
if seconds >= unit_value:
time_value = seconds / unit_value
return f"{time_value:.2f} {unit_name}"
return "0 seconds" # Handle case for 0 seconds
class Profiler:
def __init__(self):
self.timers = {}
self.counters = {}
def create_timer(self, name):
self.timers[name] = {
"start_time": None,
"elapsed_time": 0,
"running": False,
}
def start_timer(self, name):
if name not in self.timers:
raise ValueError(f"Timer '{name}' does not exist.")
if self.timers[name]["running"]:
raise ValueError(f"Timer '{name}' is already running.")
self.timers[name]["start_time"] = time.time()
self.timers[name]["running"] = True
def pause_timer(self, name):
if name not in self.timers:
raise ValueError(f"Timer '{name}' does not exist.")
if not self.timers[name]["running"]:
raise ValueError(f"Timer '{name}' is not running.")
self.timers[name]["elapsed_time"] += time.time() - self.timers[name]["start_time"]
self.timers[name]["running"] = False
def get_timer_sec(self, name):
if name not in self.timers:
raise ValueError(f"Timer '{name}' does not exist.")
if self.timers[name]["running"]:
current_time = self.timers[name]["elapsed_time"] + (time.time() - self.timers[name]["start_time"])
else:
current_time = self.timers[name]["elapsed_time"]
return current_time
def get_all_timers(self):
all_timers = {}
for name in self.timers:
all_timers[name] = self.get_timer_sec(name)
return all_timers
def report_timer_string(self, name):
return f"{name} elapsed time: {format_time(self.get_timer_sec(name))}"
def create_and_start_timer(self, name):
self.create_timer(name)
self.start_timer(name)
# Counter
def inc(self,key:str,delta:int=1):
self.counters[key] = self.counters.get(key,0) + delta
def set_counter(self,key:str,to=0):
self.counters[key] = to
def get_counter(self,key:str):
return self.counters.get(key,0)
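# Illustrative sketch (not part of the original module): the Profiler above keeps named
# timers (create/start/pause/report) and named counters (inc/get) side by side, e.g. for
# a decode loop:
def _example_profiler_usage() -> str:
    profiler = Profiler()
    profiler.create_and_start_timer("decode")
    profiler.inc("tokens", 128)       # count generated tokens
    profiler.pause_timer("decode")    # fold the elapsed time into the timer
    rate = profiler.get_counter("tokens") / max(profiler.get_timer_sec("decode"), 1e-9)
    return f"{profiler.report_timer_string('decode')}, {rate:.1f} tokens/s"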
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : chenxl
Date : 2024-06-12 09:12:58
Version : 1.0.0
LastEditors : chenxl
LastEditTime : 2024-07-27 01:56:04
'''
from urllib.parse import urlparse
import os
from contextlib import contextmanager
from sqlalchemy import create_engine
from sqlalchemy.orm import Session, sessionmaker, declarative_base
from ktransformers.server.config.config import Config
from ktransformers.server.config.singleton import Singleton
from ktransformers.server.config.log import logger
from ktransformers.server.exceptions import db_exception
Base = declarative_base()
class SQLUtil(metaclass=Singleton):
"""
Database connection initialization and management.
"""
sqlalchemy_engine = None
session_local = None
def __init__(self) -> None:
self.cfg: Config = Config()
if not self.sqlalchemy_engine:
SQLUtil.init_engine(self.cfg)
@contextmanager
def get_db(self):
"""
Yield a database session; it is closed automatically when the context exits.
"""
if not SQLUtil.sqlalchemy_engine:
SQLUtil.init_engine(self.cfg)
session = self.session_local() # type: ignore pylint: disable=not-callable
try:
yield session
finally:
session.close()
@staticmethod
def init_engine(cfg: Config):
"""
Initialize the SQLAlchemy engine and the session-maker factory.
"""
pool_size = cfg.db_pool_size
if SQLUtil.sqlalchemy_engine is None:
if cfg.db_type == "sqllite":
db_url = SQLUtil.create_sqllite_url(cfg)
else:
logger.error("Unsupported database type %s", cfg.db_type)
exit(-1)
SQLUtil.sqlalchemy_engine = create_engine(
db_url, connect_args={"check_same_thread": False}, pool_size=pool_size)
SQLUtil.session_local = sessionmaker(
autocommit=False, autoflush=False, bind=SQLUtil.sqlalchemy_engine)
@staticmethod
def create_sqllite_url(cfg):
"""
Create and validate the SQLite URL.
"""
path: str = cfg.db_host
database: str = cfg.db_database
absolute_path: str = os.path.join(path, database)
url = 'sqlite:///' + absolute_path
try:
result = urlparse(url)
if all([result.scheme, result.path, result.scheme == 'sqlite']):
return url
else:
logger.error("invalid sqllite url: %s", url)
exit(-1)
except ValueError:
logger.error("invalid sqllite url: %s", url)
exit(-1)
def db_add_commit_refresh(self, session: Session, what):
"""
add data to database
"""
try:
session.add(what)
session.commit()
session.refresh(what)
except Exception as e:
logger.exception("db commit error with data %s", str(what.__dict__))
ex = db_exception()
ex.detail = str(e)
session.rollback()
raise ex from e
def db_merge_commit(self, session: Session, what):
try:
session.merge(what)
session.commit()
except Exception as e:
ex = db_exception()
ex.detail = str(e)
logger.exception("db merge commit error with data %s", str(what.__dict__))
session.rollback()
raise ex from e
def db_update_commit_refresh(self, session: Session, existing, what):
what = what.model_dump(mode="json")
try:
for key in what.keys():
if what[key] is not None:  # skip fields that are unset in the update payload
    setattr(existing, key, what[key])  # copy the new value onto the existing row
session.commit()
session.refresh(existing)
except Exception as e:
ex = db_exception()
ex.detail = str(e)
logger.exception("db update commit refresh error with data %s", str(what.__dict__))
session.rollback()
raise ex from e
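# Illustrative sketch (not part of the original module): typical SQLUtil usage. get_db()
# yields a session that is closed when the context exits, and the commit helpers wrap
# rollback and exception translation. `record` stands for any declarative model instance
# registered on `Base`; no such placeholder class exists in this file.
def _example_sql_usage(record):
    sql_utils = SQLUtil()  # Singleton metaclass: every call returns the same instance
    with sql_utils.get_db() as db:
        sql_utils.db_add_commit_refresh(db, record)  # insert `record` and refresh it
        sql_utils.db_merge_commit(db, record)        # or upsert-style merge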
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
# add path
import sys
current_path = os.path.abspath(os.path.dirname(__file__))
sys.path.append(current_path+"/../..")
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import numpy as np
# from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
# from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
from ktransformers.util.custom_gguf import GGUFLoader
import torch
import KTransformersOps
torch.set_default_dtype(torch.bfloat16)
import time
from transformers import (
AutoConfig,
)
gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m")
model_name = "/data/Qwen2-57B-A14B-Instruct"
key = "blk.0."
target = "ffn_down_exps.weight"
t1 = time.time()
q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu")
# q_weight_cpu = torch.from_numpy(q_weight_cpu)
t2 = time.time()
q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda")
t3 = time.time()
print()
allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu().to(torch.float32), atol=1e-6)
print(f"Q6k {key+target}")
print("load gguf tensor from cpu cost: ", t2-t1)
print("load gguf tensor from gpu cost: ", t3-t2)
print("allclose: ", allclose)
key = "blk.1."
target = "ffn_up_shexp.weight"
t1 = time.time()
q_weight_cpu = gguf_config.load_gguf_tensor(key+target, "cpu")
# q_weight_cpu = torch.from_numpy(q_weight_cpu)
t2 = time.time()
q_weight_gpu = gguf_config.load_gguf_tensor(key+target, "cuda")
t3 = time.time()
print()
allclose = torch.allclose(q_weight_cpu, q_weight_gpu.cpu(), atol=1e-6)
print(f"Q4k {key+target}")
print("load gguf tensor from cpu cost: ", t2-t1)
print("load gguf tensor from gpu cost: ", t3-t2)
print("allclose: ", allclose)
import os
os.environ["CUDA_VISIBLE_DEVICES"]="1"
# add path
import sys
sys.path.append("../..")
import pycuda.autoinit
import pycuda.driver as cuda
from pycuda.compiler import SourceModule
import numpy as np
from ktransformers.operators.linear import KTransformerLinear, QuantizedLinearMarlin
from ktransformers.operators.experts import KTransformersMLPExpert, MLPExpertsTorch
from ktransformers.util.custom_gguf import GGUFLoader, dequantize_q4_k_gpu, dequantize_q4_k
import torch
import CudaOps
torch.set_default_dtype(torch.bfloat16)
import time
from transformers import (
AutoConfig,
)
gguf_config = GGUFLoader("/data/Qwen2-57B-A14B-Instruct-GGUF/q4_k_m")
model_name = "/data/Qwen2-57B-A14B-Instruct"
key = "blk.0."
target = "ffn_up_exps.weight"
data = gguf_config.get_mmap_tensor(key + target)
_, factors, offsets, qs1, qs2= dequantize_q4_k(data)
factors_cpu = torch.from_numpy(factors)
offsets_cpu = torch.from_numpy(offsets)
qs1_cpu = torch.from_numpy(qs1)
qs2_cpu = torch.from_numpy(qs2)
_, factors, offsets, qs1, qs2 = dequantize_q4_k_gpu(data)
print(torch.allclose(factors.cpu(), factors_cpu))
print(torch.allclose(offsets.cpu(), offsets_cpu))
print(torch.allclose(qs1.cpu(), qs1_cpu))
print(torch.allclose(qs2.cpu(), qs2_cpu))
'''
Description :
Author : Boxin Zhang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from typing import Dict
class CUDAGraphRunner:
def __init__(self):
self.graph = None
self.input_buffers: Dict[str, torch.Tensor] = {}
self.output_buffers: Dict[str, torch.Tensor] = {}
def capture(
self,
model,
cur_token,
position_ids,
cache_position,
past_key_values,
**kwargs,
) -> None:
assert self.graph is None
# Capture the graph.
torch.cuda.synchronize()
self.graph = torch.cuda.CUDAGraph()
#self.graph.enable_debug_mode()
self.model = model
inputs_embeds = model.model.embed_tokens(cur_token.to("cpu")).to("cuda")
with torch.cuda.graph(self.graph):
logits=model(inputs_embeds=inputs_embeds,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
**kwargs)[0]
past_key_values.change_seq_length(-1)
torch.cuda.synchronize()
#self.graph.debug_dump("cuda_graph_hooked.dot")
# Save the input and output buffers.
self.input_buffers = {
"inputs_embeds": inputs_embeds,
"position_ids": position_ids,
"cache_position": cache_position,
}
self.output_buffers = {"logits": logits}
return
def forward(
self,
cur_token,
position_ids,
cache_position,
) -> torch.Tensor:
# Copy the input tensors to the input buffers.
inputs_embeds = self.model.model.embed_tokens(cur_token.to("cpu"))
self.input_buffers["inputs_embeds"].copy_(inputs_embeds)
self.input_buffers["position_ids"].copy_(position_ids)
self.input_buffers["cache_position"].copy_(cache_position)
# Run the graph.
#print("begin replay")
#time.sleep(1)
self.graph.replay()
torch.cuda.synchronize()
# Return the output tensor.
return self.output_buffers["logits"]
def __call__(self, *args, **kwargs):
return self.forward(*args, **kwargs)
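# Illustrative sketch (not part of the original module): the intended call pattern for
# CUDAGraphRunner. The graph is captured once with representative buffers; each later decode
# step only copies the new inputs into those buffers and replays the graph. `model`, `token`,
# `position_ids`, `cache_position` and `past_key_values` are assumed to be prepared by the
# caller exactly as in capture().
def _example_cuda_graph_decode(model, token, position_ids, cache_position, past_key_values):
    runner = CUDAGraphRunner()
    runner.capture(model, token, position_ids, cache_position, past_key_values,
                   return_dict=False, use_cache=True)
    # Replays reuse the captured buffers, so inputs must keep the captured shapes.
    logits = runner(token, position_ids, cache_position)
    return logits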
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : Azure-Tang, Boxin Zhang, chenht2022
Date : 2024-07-26 08:48:54
Version : 1.0.0
LastEditors : Azure
LastEditTime : 2024-07-26 09:28:25
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
# copied from llama.cpp/gguf-py/gguf/constants.py to avoid a dependency on the gguf package
# GGUF specification
# https://github.com/ggerganov/ggml/blob/master/docs/gguf.md
import struct
import warnings
import numpy as np
import numpy.typing as npt
from typing import Sequence
import os
from enum import IntEnum
import torch
import KTransformersOps
class GGMLQuantizationType(IntEnum):
F32 = 0
F16 = 1
Q4_0 = 2
Q4_1 = 3
Q5_0 = 6
Q5_1 = 7
Q8_0 = 8
Q8_1 = 9
Q2_K = 10
Q3_K = 11
Q4_K = 12
Q5_K = 13
Q6_K = 14
Q8_K = 15
IQ2_XXS = 16
IQ2_XS = 17
IQ3_XXS = 18
IQ1_S = 19
IQ4_NL = 20
IQ3_S = 21
IQ2_S = 22
IQ4_XS = 23
I8 = 24
I16 = 25
I32 = 26
I64 = 27
F64 = 28
IQ1_M = 29
BF16 = 30
QK_K = 256
GGML_QUANT_SIZES: dict[GGMLQuantizationType, tuple[int, int]] = {
GGMLQuantizationType.F32: (1, 4),
GGMLQuantizationType.F16: (1, 2),
GGMLQuantizationType.Q4_0: (32, 2 + 16),
GGMLQuantizationType.Q4_1: (32, 2 + 2 + 16),
GGMLQuantizationType.Q5_0: (32, 2 + 4 + 16),
GGMLQuantizationType.Q5_1: (32, 2 + 2 + 4 + 16),
GGMLQuantizationType.Q8_0: (32, 2 + 32),
GGMLQuantizationType.Q8_1: (32, 4 + 4 + 32),
GGMLQuantizationType.Q2_K: (256, 2 + 2 + QK_K // 16 + QK_K // 4),
GGMLQuantizationType.Q3_K: (256, 2 + QK_K // 4 + QK_K // 8 + 12),
GGMLQuantizationType.Q4_K: (256, 2 + 2 + QK_K // 2 + 12),
GGMLQuantizationType.Q5_K: (256, 2 + 2 + QK_K // 2 + QK_K // 8 + 12),
GGMLQuantizationType.Q6_K: (256, 2 + QK_K // 2 + QK_K // 4 + QK_K // 16),
GGMLQuantizationType.Q8_K: (256, 4 + QK_K + QK_K // 8),
GGMLQuantizationType.IQ2_XXS: (256, 2 + QK_K // 4),
GGMLQuantizationType.IQ2_XS: (256, 2 + QK_K // 4 + QK_K // 32),
GGMLQuantizationType.IQ3_XXS: (256, 2 + QK_K // 4 + QK_K // 8),
GGMLQuantizationType.IQ1_S: (256, 2 + QK_K // 8 + QK_K // 16),
GGMLQuantizationType.IQ4_NL: (32, 2 + 16),
GGMLQuantizationType.IQ3_S: (256, 2 + QK_K // 4 + QK_K // 8 + QK_K // 32 + 4),
GGMLQuantizationType.IQ2_S: (256, 2 + QK_K // 4 + QK_K // 16),
GGMLQuantizationType.IQ4_XS: (256, 2 + 2 + QK_K // 2 + QK_K // 64),
GGMLQuantizationType.I8: (1, 1),
GGMLQuantizationType.I16: (1, 2),
GGMLQuantizationType.I32: (1, 4),
GGMLQuantizationType.I64: (1, 8),
GGMLQuantizationType.F64: (1, 8),
GGMLQuantizationType.IQ1_M: (256, QK_K // 8 + QK_K // 16 + QK_K // 32),
GGMLQuantizationType.BF16: (1, 2),
}
# copied from llama.cpp/gguf-py/gguf/quants.py to avoid a dependency on the gguf package
def quant_shape_to_byte_shape(shape: Sequence[int], quant_type: GGMLQuantizationType):
block_size, type_size = GGML_QUANT_SIZES[quant_type]
if shape[-1] % block_size != 0:
raise ValueError(f"Quantized tensor row size ({shape[-1]}) is not a multiple of {quant_type.name} block size ({block_size})")
return (*shape[:-1], shape[-1] // block_size * type_size)
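# Worked example (illustrative, shapes assumed): a Q4_K tensor with logical shape
# (4096, 14336) stores each row as 14336 // 256 = 56 blocks of 144 bytes, so its byte
# shape is (4096, 56 * 144) == (4096, 8064).
def _example_q4_k_byte_shape():
    return quant_shape_to_byte_shape((4096, 14336), GGMLQuantizationType.Q4_K)  # (4096, 8064)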
GGML_TYPES = {
"F32": 0,
"F16": 1,
"Q8_0": 8,
"Q2_K": 10,
"Q3_K": 11,
"Q4_K": 12,
"Q5_K": 13,
"Q6_K": 14,
}
GGML_NAMES = {ggml_type: name for name, ggml_type in GGML_TYPES.items()}
GGML_BLOCK_SIZES = {
"F32": 4,
"F16": 2,
"Q8_0": 2 + 32,
"Q2_K": 256 // 16 + 256 // 4 + 2 + 2,
"Q3_K": 256 // 8 + 256 // 4 + 12 + 2,
"Q4_K": 2 + 2 + 12 + 256 // 2,
"Q5_K": 2 + 2 + 12 + 256 // 8 + 256 // 2,
"Q6_K": 256 // 2 + 256 // 4 + 256 // 16 + 2,
}
GGML_ELEMENTS_PER_BLOCK = {
"F32": 1,
"F16": 1,
"Q8_0": 32,
"Q2_K": 256,
"Q3_K": 256,
"Q4_K": 256,
"Q5_K": 256,
"Q6_K": 256,
}
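# Illustrative sketch (not part of the original module): the two tables above imply the
# effective bits per weight of each quantization, e.g. Q4_K packs 256 elements into
# 144 bytes, i.e. 144 * 8 / 256 = 4.5 bits per weight.
def _example_bits_per_weight(quant_name: str = "Q4_K") -> float:
    return GGML_BLOCK_SIZES[quant_name] * 8 / GGML_ELEMENTS_PER_BLOCK[quant_name]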
# DATA_TYPES = {
# "uint32": 4,
# "int32": 5,
# "float32": 6,
# "string": 8,
# "array": 9,
# "uint64": 10,
# }
DATA_TYPES = {
"uint8": 0,
"int8": 1,
"uint16": 2,
"int16": 3,
"uint32": 4,
"int32": 5,
"float32": 6,
"bool": 7,
"string": 8,
"array": 9,
"uint64": 10,
"int64": 11,
"float64": 12,
}
class GGUFLoader:
tensor_info: dict
gguf_path: str
tensor_file_map: dict # {tensor_name: tensor_file_path}
gguf_file_meta: dict
def __init__(self, gguf_path: str):
# Check dir exist
if not os.path.exists(gguf_path):
raise FileNotFoundError(f"GGUF dir not found: {gguf_path}")
self.tensor_info = {}
self.gguf_path = gguf_path
self.tensor_file_map = {}
self.file_data_map = {}
self.gguf_file_meta = {}
# Walk through all the .gguf files in the directory
for root, dirs, files in os.walk(gguf_path):
for file in files:
if file.endswith(".gguf"):
file_name = os.path.join(root, file)
with open(file_name, "rb") as f:
self.load_gguf(f)
if file_name not in self.file_data_map:
self.file_data_map[file_name] = np.memmap(file_name, mode = 'r')
def load_gguf(self, f):
f.seek(0)
assert f.read(4) == b'GGUF'
values = struct.unpack("<IQQ", f.read(4+8+8))
version, n_tensors, n_kv = values
if version != 3:
warnings.warn(f"Version {version} has never been tested, might not work")
info = {}
for _ in range(n_kv):
name = read_value(f, DATA_TYPES["string"])
data_type = struct.unpack("<I", f.read(4))[0]
info[name] = read_value(f, data_type)
tensor_info = {}
for _ in range(n_tensors):
name = read_value(f, DATA_TYPES["string"])
shape_len = read_value(f, DATA_TYPES["uint32"])
shape = [read_value(f, DATA_TYPES["uint64"]) for _ in range(shape_len)]
ggml_type = read_value(f, DATA_TYPES["uint32"])
bad_offset = read_value(f, DATA_TYPES["uint64"])
n_elems = int(np.prod(shape))
block_size, type_size = GGML_QUANT_SIZES[ggml_type]
n_bytes = n_elems * type_size // block_size
np_dims = tuple(reversed(shape))
item_type: npt.DTypeLike
if ggml_type == GGMLQuantizationType.F16:
item_count = n_elems
item_type = np.float16
elif ggml_type == GGMLQuantizationType.F32:
item_count = n_elems
item_type = np.float32
elif ggml_type == GGMLQuantizationType.F64:
item_count = n_elems
item_type = np.float64
elif ggml_type == GGMLQuantizationType.I8:
item_count = n_elems
item_type = np.int8
elif ggml_type == GGMLQuantizationType.I16:
item_count = n_elems
item_type = np.int16
elif ggml_type == GGMLQuantizationType.I32:
item_count = n_elems
item_type = np.int32
elif ggml_type == GGMLQuantizationType.I64:
item_count = n_elems
item_type = np.int64
else:
item_count = n_bytes
item_type = np.uint8
np_dims = quant_shape_to_byte_shape(np_dims, ggml_type)
tensor_info[name] = {
"ggml_type": ggml_type,
"shape": shape,
"bad_offset": bad_offset,
"item_type": item_type,
"item_count": item_count,
"np_dims": np_dims
}
start = f.tell()
# Alignment is 32 by default.
# https://github.com/ggerganov/ggml/blob/e1daebbf9d38d510ba456c4d50b4500a73ac2b14/docs/gguf.md?plain=1#L253
alignment = info.get("general.alignment", 32)
# Inconveniently, the offset defined in gguf files is relative to the
# end of the header and is unaligned.
# We need to compute the absolute file offset ourselves instead.
for t in tensor_info.values():
offset = start + t["bad_offset"]
offset += (alignment - offset % alignment) % alignment
t["offset"] = offset
for name in tensor_info:
self.tensor_file_map[name] = f.name
self.tensor_info.update(tensor_info)
self.gguf_file_meta.update(info)
def get_mmap_tensor(self, name):
t = self.tensor_info[name]
mmap_data = self.file_data_map[ self.tensor_file_map[name] ]
offset = t["offset"]
item_type = t["item_type"]
item_count = t["item_count"]
itemsize = int(np.empty([], dtype = item_type).itemsize)
return mmap_data[offset : offset + itemsize * item_count]
def load_gguf_tensor(self, name: str, device:str = "cpu")->torch.Tensor:
t = self.tensor_info[name]
shape = t["shape"]
ggml_type = t["ggml_type"]
if ggml_type not in GGML_NAMES:
raise NotImplementedError(f"ggml_type {ggml_type} not implemented")
ggml_name = GGML_NAMES[ggml_type]
data = self.get_mmap_tensor(name)
if "cuda" in device.lower():
values = GGML_DEQUANTIZE_GPU[ggml_name](data, device)
else:
values = GGML_DEQUANTIZE[ggml_name](data)
values = torch.from_numpy(values)
return values.view(shape[::-1])
def read_value(f, data_type):
if data_type == DATA_TYPES["string"]:
length = struct.unpack("<Q", f.read(8))[0]
return f.read(length).decode("utf-8")
elif data_type == DATA_TYPES["bool"]:
return bool(struct.unpack("<?", f.read(1))[0])
elif data_type == DATA_TYPES["uint8"]:
return struct.unpack("<B", f.read(1))[0]
elif data_type == DATA_TYPES["int8"]:
return struct.unpack("<b", f.read(1))[0]
elif data_type == DATA_TYPES["uint16"]:
return struct.unpack("<H", f.read(2))[0]
elif data_type == DATA_TYPES["int16"]:
return struct.unpack("<h", f.read(2))[0]
elif data_type == DATA_TYPES["uint32"]:
return struct.unpack("<I", f.read(4))[0]
elif data_type == DATA_TYPES["int32"]:
return struct.unpack("<i", f.read(4))[0]
elif data_type == DATA_TYPES["float32"]:
return struct.unpack("<f", f.read(4))[0]
elif data_type == DATA_TYPES["uint64"]:
return struct.unpack("<Q", f.read(8))[0]
elif data_type == DATA_TYPES["int64"]:
return struct.unpack("<q", f.read(8))[0]
elif data_type == DATA_TYPES["float64"]:
return struct.unpack("<d", f.read(8))[0]
elif data_type == DATA_TYPES["array"]:
elem_type, count = struct.unpack("<IQ", f.read(4 + 8))
return [read_value(f, elem_type) for _ in range(count)]
else:
raise NotImplementedError(f"Data type {data_type} not implemented")
def dequantize_q2_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1547
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L74
block_size = GGML_BLOCK_SIZES["Q2_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
dmin = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
d = data_f16[:, -2].reshape(num_blocks, 1, 1).astype(np.float32)
scales = data_u8[:, :16].reshape(num_blocks, 16, 1)
qs = data_u8[:, 16:80].reshape(num_blocks, 64)
tmp = np.stack([
qs[:, 00:16] >> 0,
qs[:, 16:32] >> 0,
qs[:, 00:16] >> 2,
qs[:, 16:32] >> 2,
qs[:, 00:16] >> 4,
qs[:, 16:32] >> 4,
qs[:, 00:16] >> 6,
qs[:, 16:32] >> 6,
qs[:, 32:48] >> 0,
qs[:, 48:64] >> 0,
qs[:, 32:48] >> 2,
qs[:, 48:64] >> 2,
qs[:, 32:48] >> 4,
qs[:, 48:64] >> 4,
qs[:, 32:48] >> 6,
qs[:, 48:64] >> 6,
], axis=1)
return d * (scales & 15) * (tmp & 3) - dmin * (scales >> 4)
def dequantize_q2_k_gpu(data, device: str = "cuda"):
    raise NotImplementedError("Q2_K GPU dequantization is not implemented; use the CPU path")
def dequantize_q3_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1723C32-L1723C42
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L95
block_size = GGML_BLOCK_SIZES["Q3_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
d = data_f16[:, -1].reshape(num_blocks, 1, 1).astype(np.float32)
bits = np.unpackbits(data_u8[:, :32].reshape(num_blocks, 32, 1), axis=-1, bitorder="little")
bits = 4 ^ (bits << 2)
qs = data_u8[:, 32:32 + 64].astype(np.int16)
a, b, c = data_u8[:, 96: 96 + 12].reshape(num_blocks, 3, 4).transpose(1, 0, 2)
scales = np.zeros((num_blocks, 4, 4), dtype=np.uint8)
scales[:, 0] = (a & 15) | ((c & 3) << 4)
scales[:, 1] = (b & 15) | (((c >> 2) & 3) << 4)
scales[:, 2] = (a >> 4) | (((c >> 4) & 3) << 4)
scales[:, 3] = (b >> 4) | ((c >> 6) << 4)
scales = scales.reshape(num_blocks, 16, 1).astype(np.int16)
return d * (scales - 32) * np.stack([
(((qs[:, 00:16] >> 0) & 3) - bits[:, :16, 0]),
(((qs[:, 16:32] >> 0) & 3) - bits[:, 16:, 0]),
(((qs[:, 00:16] >> 2) & 3) - bits[:, :16, 1]),
(((qs[:, 16:32] >> 2) & 3) - bits[:, 16:, 1]),
(((qs[:, 00:16] >> 4) & 3) - bits[:, :16, 2]),
(((qs[:, 16:32] >> 4) & 3) - bits[:, 16:, 2]),
(((qs[:, 00:16] >> 6) & 3) - bits[:, :16, 3]),
(((qs[:, 16:32] >> 6) & 3) - bits[:, 16:, 3]),
(((qs[:, 32:48] >> 0) & 3) - bits[:, :16, 4]),
(((qs[:, 48:64] >> 0) & 3) - bits[:, 16:, 4]),
(((qs[:, 32:48] >> 2) & 3) - bits[:, :16, 5]),
(((qs[:, 48:64] >> 2) & 3) - bits[:, 16:, 5]),
(((qs[:, 32:48] >> 4) & 3) - bits[:, :16, 6]),
(((qs[:, 48:64] >> 4) & 3) - bits[:, 16:, 6]),
(((qs[:, 32:48] >> 6) & 3) - bits[:, :16, 7]),
(((qs[:, 48:64] >> 6) & 3) - bits[:, 16:, 7])
], axis=1)
def dequantize_q3_k_gpu(data, device: str = "cuda"):
    raise NotImplementedError("Q3_K GPU dequantization is not implemented; use the CPU path")
def dequantize_q4_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L1929
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L116
block_size = GGML_BLOCK_SIZES["Q4_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
# Casting to float32 because float16 is very slow on CPU
scale_factors = data_f16[:, 0].reshape(num_blocks, 1, 1).astype(np.float32)
scale_offsets = data_f16[:, 1].reshape(num_blocks, 1, 1).astype(np.float32)
qs1 = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qs2 = data_u8[:, 16:].reshape(num_blocks, 4, 32)
# Dequantize scales and offsets (6 bits and 4 + 2 bits)
factors = scale_factors * np.concatenate([qs1[:, 0:4] & 0b111111, (qs1[:, 8:] & 15) | ((qs1[:, 0:4] >> 6) << 4)], axis=1)
offsets = scale_offsets * np.concatenate([qs1[:, 4:8] & 0b111111, (qs1[:, 8:] >> 4) | ((qs1[:, 4:8] >> 6) << 4)], axis=1)
# Interleave low and high quantized bits
qs2 = np.stack([qs2 & 0xf, qs2 >> 4], axis=2).reshape(num_blocks, 8, 32)
# Dequantize final weights using scales and offsets
return factors * qs2 - offsets
def dequantize_q4_k_gpu(data, device:str ="cuda"):
data = np.frombuffer(data, dtype=data.dtype)
device = torch.device(device)
# TODO: this and from_numpy in the other GPU functions trigger a warning that the numpy array is not writable;
# the cleanest fix is to pass the raw pointer to KTransformersOps instead of a Tensor.
data = torch.from_numpy(data)
return KTransformersOps.dequantize_q4_k(data, 144, device)
def dequantize_q5_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2129
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L138
block_size = GGML_BLOCK_SIZES["Q5_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
d = data_f16[:, 0].reshape(num_blocks, 1).astype(np.float32)
dmin = data_f16[:, 1].reshape(num_blocks, 1).astype(np.float32)
scales = data_u8[:, 4:16].reshape(num_blocks, 12, 1)
qh = data_u8[:, 16: 16 + 32].reshape(num_blocks, 32, 1)
qs = data_u8[:, 48: 48 + 128].reshape(num_blocks, 4, 32)
bits = np.unpackbits(qh, axis=-1, bitorder="little")
qs_hi_4 = qs >> 4
qs_lo_4 = qs & 15
scales_lo_6 = scales[:, :8] & 63
scales_hi_6 = scales[:, :8] >> 6
scales_lo_4 = scales[:, 8:] & 15
scales_hi_4 = scales[:, 8:] >> 4
m1 = dmin * scales_lo_6[:, 4]
m2 = dmin * scales_lo_6[:, 5]
m3 = dmin * scales_lo_6[:, 6]
m4 = dmin * scales_lo_6[:, 7]
m5 = dmin * (scales_hi_4[:, 0] | (scales_hi_6[:, 4] << 4))
m6 = dmin * (scales_hi_4[:, 1] | (scales_hi_6[:, 5] << 4))
m7 = dmin * (scales_hi_4[:, 2] | (scales_hi_6[:, 6] << 4))
m8 = dmin * (scales_hi_4[:, 3] | (scales_hi_6[:, 7] << 4))
d1 = d * scales_lo_6[:, 0]
d2 = d * scales_lo_6[:, 1]
d3 = d * scales_lo_6[:, 2]
d4 = d * scales_lo_6[:, 3]
d5 = d * (scales_lo_4[:, 0] | (scales_hi_6[:, 0] << 4))
d6 = d * (scales_lo_4[:, 1] | (scales_hi_6[:, 1] << 4))
d7 = d * (scales_lo_4[:, 2] | (scales_hi_6[:, 2] << 4))
d8 = d * (scales_lo_4[:, 3] | (scales_hi_6[:, 3] << 4))
return np.concatenate([
d1 * (qs_lo_4[:, 0] + (bits[:, :, 0] << 4)) - m1,
d2 * (qs_hi_4[:, 0] + (bits[:, :, 1] << 4)) - m2,
d3 * (qs_lo_4[:, 1] + (bits[:, :, 2] << 4)) - m3,
d4 * (qs_hi_4[:, 1] + (bits[:, :, 3] << 4)) - m4,
d5 * (qs_lo_4[:, 2] + (bits[:, :, 4] << 4)) - m5,
d6 * (qs_hi_4[:, 2] + (bits[:, :, 5] << 4)) - m6,
d7 * (qs_lo_4[:, 3] + (bits[:, :, 6] << 4)) - m7,
d8 * (qs_hi_4[:, 3] + (bits[:, :, 7] << 4)) - m8,
], axis=1)
def dequantize_q5_k_gpu(data, device: str = "cuda"):
    raise NotImplementedError("Q5_K GPU dequantization is not implemented; use the CPU path")
def dequantize_q6_k(data):
# C implementation
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.c#L2275
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L152
block_size = GGML_BLOCK_SIZES["Q6_K"]
num_blocks = len(data) // block_size
data_f16 = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, block_size // 2)
data_u8 = np.frombuffer(data, dtype=np.uint8).reshape(num_blocks, block_size)
data_i8 = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, block_size)
scales = data_f16[:, -1].reshape(num_blocks, 1).astype(np.float32)
# TODO use uint8 and cast later?
ql = data_u8[:, :128].astype(np.int16)
qh = data_u8[:, 128:192].astype(np.int16)
sc = data_i8[:, 192:208, np.newaxis].astype(np.float32)
# Unpack bits, subtraction requires signed data type
q1 = (ql[:, :32 ] & 0xF) | (((qh[:, :32] >> 0) & 3) << 4) - 32
q2 = (ql[:, 32:64 ] & 0xF) | (((qh[:, :32] >> 2) & 3) << 4) - 32
q3 = (ql[:, :32 ] >> 4) | (((qh[:, :32] >> 4) & 3) << 4) - 32
q4 = (ql[:, 32:64 ] >> 4) | (((qh[:, :32] >> 6) & 3) << 4) - 32
q5 = (ql[:, 64:96 ] & 0xF) | (((qh[:, 32:] >> 0) & 3) << 4) - 32
q6 = (ql[:, 96:128] & 0xF) | (((qh[:, 32:] >> 2) & 3) << 4) - 32
q7 = (ql[:, 64:96 ] >> 4) | (((qh[:, 32:] >> 4) & 3) << 4) - 32
q8 = (ql[:, 96:128] >> 4) | (((qh[:, 32:] >> 6) & 3) << 4) - 32
# Dequantize
return scales * np.concatenate([
sc[:, 0] * q1[:, :16],
sc[:, 1] * q1[:, 16:],
sc[:, 2] * q2[:, :16],
sc[:, 3] * q2[:, 16:],
sc[:, 4] * q3[:, :16],
sc[:, 5] * q3[:, 16:],
sc[:, 6] * q4[:, :16],
sc[:, 7] * q4[:, 16:],
sc[:, 8] * q5[:, :16],
sc[:, 9] * q5[:, 16:],
sc[:, 10] * q6[:, :16],
sc[:, 11] * q6[:, 16:],
sc[:, 12] * q7[:, :16],
sc[:, 13] * q7[:, 16:],
sc[:, 14] * q8[:, :16],
sc[:, 15] * q8[:, 16:],
], axis=1)
# @torch.jit.script
def dequantize_q6_k_gpu(data: np.ndarray, device:str = "cuda"):
block_size = GGML_BLOCK_SIZES["Q6_K"]
device = torch.device(device)
num_blocks = len(data) // block_size
data = np.frombuffer(data, dtype=data.dtype)
data = torch.from_numpy(data)
return KTransformersOps.dequantize_q6_k(data, 210, device)
def dequantize_q8_0(data):
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
scales = np.frombuffer(data, dtype=np.float16).reshape(num_blocks, 1 + 16)[:, :1].astype(np.float32)
qs = np.frombuffer(data, dtype=np.int8).reshape(num_blocks, 2 + 32)[:, 2:]
return scales * qs
def dequantize_q8_0_gpu(data, device:str = "cuda"):
# C struct definition
# https://github.com/ggerganov/ggml/blob/fca1caafea7de9fbd7efc733b9818f9cf2da3050/src/ggml-quants.h#L43
num_blocks = len(data) // GGML_BLOCK_SIZES["Q8_0"]
device = torch.device(device)
data = np.frombuffer(data, dtype=data.dtype)
data = torch.from_numpy(data)
return KTransformersOps.dequantize_q8_0(data, 34, device)
def dequantize_f32(data):
return np.frombuffer(data, dtype=np.float32)
def dequantize_f32_gpu(data, device):
data = np.frombuffer(data, dtype=np.float32)
res = torch.from_numpy(data)
res_gpu = torch.empty_like(res, device=device)
res_gpu.copy_(res)
return res_gpu
def dequantize_f16(data):
return np.frombuffer(data, dtype=np.float16)
def dequantize_f16_gpu(data, device):
data = np.frombuffer(data, dtype=np.float16)
res = torch.from_numpy(data)
res_gpu = torch.empty_like(res, device=device)
res_gpu.copy_(res)
return res_gpu
GGML_DEQUANTIZE = {
"F32": dequantize_f32,
"F16": dequantize_f16,
"Q8_0": dequantize_q8_0,
"Q2_K": dequantize_q2_k,
"Q3_K": dequantize_q3_k,
"Q4_K": dequantize_q4_k,
"Q5_K": dequantize_q5_k,
"Q6_K": dequantize_q6_k,
}
GGML_DEQUANTIZE_GPU = {
"F32": dequantize_f32_gpu,
"F16": dequantize_f16_gpu,
"Q8_0": dequantize_q8_0_gpu,
"Q2_K": dequantize_q2_k_gpu,
"Q3_K": dequantize_q3_k_gpu,
"Q4_K": dequantize_q4_k_gpu,
"Q5_K": dequantize_q5_k_gpu,
"Q6_K": dequantize_q6_k_gpu,
}
def translate_name_to_gguf(name):
name = name.replace("lm_head.", "output.")
name = name.replace("model.embed_tokens.", "token_embd.")
name = name.replace("model.norm.", "output_norm.")
name = name.replace("model.layers.", "blk.")
name = name.replace(".input_layernorm", ".attn_norm")
name = name.replace(".mlp.down_proj", ".ffn_down")
name = name.replace(".mlp.gate_proj", ".ffn_gate")
name = name.replace(".mlp.up_proj", ".ffn_up")
name = name.replace(".post_attention_layernorm", ".ffn_norm")
name = name.replace(".self_attn.q_proj", ".attn_q")
name = name.replace(".self_attn.k_proj", ".attn_k")
name = name.replace(".self_attn.v_proj", ".attn_v")
name = name.replace(".self_attn.o_proj", ".attn_output")
name = name.replace(".self_attn.qkv_proj", ".attn_qkv")
name = name.replace(".self_attn.kv_a_proj_with_mqa", ".attn_kv_a_mqa")
name = name.replace(".self_attn.kv_a_layernorm", ".attn_kv_a_norm")
name = name.replace(".self_attn.kv_b_proj", ".attn_kv_b")
name = name.replace(".self_attn.q_a_proj", ".attn_q_a")
name = name.replace(".self_attn.q_a_layernorm", ".attn_q_a_norm")
name = name.replace(".self_attn.q_b_proj", ".attn_q_b")
name = name.replace(".shared_expert.", ".shared_experts.")
name = name.replace(".shared_expert_", ".shared_experts_")
name = name.replace(".gate_up_proj.", ".up_proj")
name = name.replace(".mlp.shared_experts.down_proj", ".ffn_down_shexp")
name = name.replace(".mlp.gate", ".ffn_gate_inp")
name = name.replace(".mlp.shared_experts.gate_proj", ".ffn_gate_shexp")
name = name.replace(".mlp.shared_experts.up_proj", ".ffn_up_shexp")
name = name.replace(".mlp.shared_experts_gate", ".ffn_gate_inp_shexp")
name = name.replace(".mlp.experts", "")
name = name.replace(".mlp.experts.ffn_down_exps", ".ffn_down_exps")
name = name.replace(".mlp.experts.ffn_gate_exps", ".ffn_gate_exps")
name = name.replace(".mlp.experts.ffn_up_exps", ".ffn_up_exps")
return name
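# Illustrative examples (not part of the original module) of the mapping above, using
# typical HuggingFace parameter names:
def _example_translate_names():
    assert translate_name_to_gguf("model.layers.0.self_attn.q_proj.weight") == "blk.0.attn_q.weight"
    assert translate_name_to_gguf("model.embed_tokens.weight") == "token_embd.weight"
    assert translate_name_to_gguf("model.layers.3.mlp.down_proj.weight") == "blk.3.ffn_down.weight"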
if __name__ == '__main__':
gguf_path = '/mnt/data/model/DeepSeek-Coder-V2-GGUF-WJH'
loader = GGUFLoader(gguf_path)
loader.load_gguf_tensor('token_embd.weight')
from typing import Any, List, Optional, Set
class TextStreamer:
def __init__(self, tokenizer: "AutoTokenizer", skip_prompt: bool = False, **decode_kwargs):
self.tokenizer = tokenizer
self.skip_prompt = skip_prompt
self.decode_kwargs = decode_kwargs
# variables used in the streaming process
self.token_cache = []
self.print_len = 0
self.next_tokens_are_prompt = True
def reset(self):
self.token_cache = []
self.print_len = 0
def put(self, value)->Optional[str]:
"""
Receives tokens, decodes them, and returns complete words as soon as they are formed.
"""
if not isinstance(value,int):
raise ValueError("TextStreamer only supports batch size 1, and int type input")
if self.skip_prompt and self.next_tokens_are_prompt:
self.next_tokens_are_prompt = False
return None
# Add the new token to the cache and decode the accumulated tokens.
self.token_cache.append(value)
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True,**self.decode_kwargs)
# After the symbol for a new line, we flush the cache.
if text.endswith("\n"):
printable_text = text[self.print_len :]
self.reset()
# If the last token is a CJK character, we print the characters.
elif len(text) > 0 and self._is_chinese_char(ord(text[-1])):
printable_text = text[self.print_len :]
self.print_len += len(printable_text)
# Otherwise, prints until the last space char (simple heuristic to avoid printing incomplete words,
# which may change with the subsequent token -- there are probably smarter ways to do this!)
else:
printable_text = text[self.print_len : text.rfind(" ") + 1]
self.print_len += len(printable_text)
return printable_text
def end(self)->Optional[str]:
"""Flushes any remaining cache and prints a newline to stdout."""
# Flush the cache, if it exists
if len(self.token_cache) > 0:
text = self.tokenizer.decode(self.token_cache, skip_special_tokens=True, **self.decode_kwargs)
printable_text = text[self.print_len :]
self.reset()
else:
printable_text = ""
self.next_tokens_are_prompt = True
return printable_text
def _is_chinese_char(self, cp):
"""Checks whether CP is the codepoint of a CJK character."""
# This defines a "chinese character" as anything in the CJK Unicode block:
# https://en.wikipedia.org/wiki/CJK_Unified_Ideographs_(Unicode_block)
#
# Note that the CJK Unicode block is NOT all Japanese and Korean characters,
# despite its name. The modern Korean Hangul alphabet is a different block,
# as is Japanese Hiragana and Katakana. Those alphabets are used to write
# space-separated words, so they are not treated specially and handled
# like all of the other languages.
if (
(cp >= 0x4E00 and cp <= 0x9FFF)
or (cp >= 0x3400 and cp <= 0x4DBF) #
or (cp >= 0x20000 and cp <= 0x2A6DF) #
or (cp >= 0x2A700 and cp <= 0x2B73F) #
or (cp >= 0x2B740 and cp <= 0x2B81F) #
or (cp >= 0x2B820 and cp <= 0x2CEAF) #
or (cp >= 0xF900 and cp <= 0xFAFF)
or (cp >= 0x2F800 and cp <= 0x2FA1F) #
): #
return True
return False
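# Illustrative sketch (not part of the original module): feeding tokens one at a time.
# `tokenizer` is any HuggingFace-style tokenizer with a decode() method; the token ids are
# placeholders supplied by the caller.
def _example_stream_tokens(tokenizer, token_ids):
    streamer = TextStreamer(tokenizer)
    pieces = []
    for token_id in token_ids:
        piece = streamer.put(token_id)  # returns text only once whole words are ready
        if piece:
            pieces.append(piece)
    pieces.append(streamer.end())       # flush whatever is still buffered
    return "".join(pieces)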
#!/usr/bin/env python
# coding=utf-8
'''
Description :
Author : Boxin Zhang, Azure-Tang
Version : 0.1.0
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import torch
from torch import nn
import itertools
import time
import enum
from ktransformers.util.custom_gguf import translate_name_to_gguf
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.operators import base_operator
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.util.textstream import TextStreamer
def set_module(model, submodule_key, module):
tokens = submodule_key.split('.')
sub_tokens = tokens[:-1]
cur_mod = model
for s in sub_tokens:
if hasattr(cur_mod, s):
cur_mod = getattr(cur_mod, s)
else: # an indexable container such as nn.ModuleList or nn.Sequential
cur_mod=cur_mod[int(s)]
if hasattr(cur_mod, tokens[-1]):
setattr(cur_mod, tokens[-1], module)
else: # an indexable container such as nn.ModuleList or nn.Sequential
cur_mod[int(tokens[-1])] = module
def set_param(module: nn.Module, name: str, weights: torch.Tensor):
param=nn.parameter.Parameter(weights, requires_grad=False)
if isinstance(module, nn.Linear) and len(weights.shape)==1:
param.unsqueeze_(0)
setattr(module, name, param)
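# Illustrative sketch (not part of the original module): set_module / set_param address
# submodules by dotted path, including numeric indices into nn.ModuleList, e.g.
# "model.layers.0.self_attn.q_proj". The replacement module and weight tensor below are
# placeholders supplied by the caller.
def _example_inject(model: nn.Module, replacement: nn.Module, weights: torch.Tensor):
    set_module(model, "model.layers.0.self_attn.q_proj", replacement)
    set_param(replacement, "weight", weights)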
def load_cur_state_dict(module: nn.Module, gguf_loader: GGUFLoader, prefix: str = ""):
prefix = prefix.replace("orig_module.", "")
persistent_buffers = {k: v for k, v in module._buffers.items() if k not in module._non_persistent_buffers_set}
local_name_params = itertools.chain(module._parameters.items(), persistent_buffers.items())
local_state = {k: v for k, v in local_name_params if v is not None}
for name, param in local_state.items():
key = prefix + name
translated_key = translate_name_to_gguf(key)
print("default loading weights", key, translated_key)
if translated_key in gguf_loader.tensor_file_map:
target_dtype = torch.get_default_dtype()
device = "cpu" if "embd" in translated_key else "cuda"
weights = gguf_loader.load_gguf_tensor(translated_key, device = device).to(dtype = target_dtype)
set_param(module, name, weights)
del weights
else:
#print(load_config.tensor_file_map.keys())
raise Exception(f"can't fand {translated_key} in GGUF file!")
def load_weights(module:nn.Module, gguf_loader:GGUFLoader, prefix='', return_when_injected:bool = False, only_load_injected:bool = False):
# print(f"recursively loading weights {prefix},{return_when_injected=}, {only_load_injected=}")
if not isinstance(module, base_operator.BaseInjectedModule):
load_cur_state_dict(module, gguf_loader, prefix)
for name, child in module._modules.items():
load_weights(child, gguf_loader, prefix+name+".")
else:
module.load()
def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000):
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
torch._dynamo.config.suppress_errors = True
batch_size, seq_length = inputs.shape
torch_device = inputs.device
tokens = []
def decode_one_tokens(cuda_graph_runner, cur_token, position_ids, cache_position, past_key_values):
logits = cuda_graph_runner(cur_token, position_ids, cache_position)
past_key_values.change_seq_length(1)
"""
with torch.cuda.stream(custom_stream):
logits=model(cur_token,
position_ids=position_ids,
cache_position=cache_position,
past_key_values=past_key_values,
return_dict=False, use_cache=True)[0]
#"""
torch.cuda.synchronize()
#print(logits)
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_token = torch.argmax(next_token_scores, dim=-1)
return next_token
with torch.no_grad():
stream = TextStreamer(tokenizer)
past_key_values = StaticCache(
config = model.config, max_batch_size = 1, max_cache_len = seq_length + max_new_tokens, device = torch_device, dtype = model.dtype
)
cache_position = torch.arange(seq_length, device=torch_device)
generated_ids = torch.zeros(
batch_size, seq_length + max_new_tokens + 1, dtype=torch.int, device=torch_device
)
generated_ids[:, cache_position] = inputs.to(torch_device).to(torch.int)
past_key_values.cur_idx=cache_position
start_time = time.time()
#custom_stream = torch.cuda.Stream()
inputs_embeds = model.model.embed_tokens(inputs.to("cpu")).to("cuda")
logits = model(
inputs_embeds = inputs_embeds, cache_position=cache_position, past_key_values=past_key_values, return_dict=False, use_cache=True
)[0][:,-1,:].unsqueeze(0).clone()
generation_config, model_kwargs = model._prepare_generation_config(
None, max_length=max_new_tokens,
do_sample=True, top_k=5, top_p=0.85, temperature=0.1 # change this to modify generate config
)
try: # transformers==4.43
logits_warper = (
model._get_logits_warper(generation_config,device=inputs.device) if generation_config.do_sample else None
)
except:
logits_warper = (
model._get_logits_warper(generation_config) if generation_config.do_sample else None
)
next_token_scores = logits_warper(inputs, logits[:, -1, :])
if generation_config.do_sample:
probs = nn.functional.softmax(next_token_scores, dim=-1)
next_token = torch.multinomial(probs, num_samples=1).squeeze(1)
else:
next_token = torch.argmax(next_token_scores, dim=-1)
first_token_time = time.time() - start_time
prefill_count = seq_length
prefill_time = first_token_time
print(stream.put(next_token.item()), end="", flush=True)
generated_ids[:, seq_length] = next_token
tokens.append(next_token)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
cache_position = torch.tensor([seq_length], device=torch_device)
position_ids = cache_position.unsqueeze(0)
seq_length += 1
cuda_graph_runner = CUDAGraphRunner()
cuda_graph_runner.capture(model, next_token.unsqueeze(0), position_ids, cache_position, past_key_values, return_dict=False, use_cache=True)
start_time = time.time()
for _ in range(1, max_new_tokens):
next_token = decode_one_tokens(cuda_graph_runner, next_token.unsqueeze(0), position_ids, cache_position, past_key_values)
inputs = torch.cat((inputs, next_token.unsqueeze(0)), dim=-1)
generated_ids[:, cache_position] = next_token.int()
tokens.append(next_token.int())
seq_length += 1
if next_token[0].item() == tokenizer.eos_token_id:
print(stream.end(), end="", flush=True)
break
else:
print(stream.put(next_token.item()), end="", flush=True)
cache_position += 1
position_ids = cache_position.unsqueeze(0)
total_time = time.time() - start_time
tokens_generated = len(tokens)
tokens_per_second = tokens_generated / total_time
print("")
print(f"prompt eval count: {prefill_count} token(s)")
print(f"prompt eval duration: {prefill_time}s")
print(f"prompt eval rate: {prefill_count/prefill_time} tokens/s")
print(f"eval count: {tokens_generated} token(s)")
print(f"eval duration: {total_time}s")
print(f"eval rate: {tokens_per_second} tokens/s")
return tokens
class InferenceState(enum.Enum):
UNLOAD = 0
PREFILL = 1
GENERATE = 2
RESTORE = 3
> 1%
last 2 versions
not dead
not ie 11