Commit 412055d4 authored by Atream

[feature] experts can be injected using CPUInfer

[fix] fix the ktransformers interface when using the new CUDAGraphRunner
[fix] fix YAML configs and optimize logic; the top rule now has the highest priority
parent 80815dbc
import sys, os
from typing import Any
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.server.config.config import Config

class CPUInfer:
    # Process-wide singleton wrapper: every CPUInfer() shares one cpuinfer_ext.CPUInfer
    # backend, and all attribute access is proxied to that shared instance.
    cpu_infer = None
    def __init__(self, cpu_infer: int = Config().cpu_infer):
        if CPUInfer.cpu_infer is None:
            CPUInfer.cpu_infer = cpuinfer_ext.CPUInfer(cpu_infer)

    def __getattribute__(self, __name: str) -> Any:
        return CPUInfer.cpu_infer.__getattribute__(__name)

    def __setattr__(self, __name: str, __value: Any) -> None:
        return CPUInfer.cpu_infer.__setattr__(__name, __value)
\ No newline at end of file
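For orientation, a minimal usage sketch of this proxy under the conventions the rest of the commit relies on (the thread count and the commented-out task construction are illustrative, not values taken from the repository):

import torch
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer

cpu_infer = CPUInfer(48)  # hypothetical thread count; every CPUInfer() shares one cpuinfer_ext.CPUInfer backend
# A task object (e.g. cpuinfer_ext.linear.Linear(...).warm_up() or .forward(...)) is enqueued and awaited:
# cpu_infer.submit(task)
# cpu_infer.sync()
# For single-token decode the task can instead be ordered against the current CUDA stream:
# cpu_infer.submit_with_cuda_stream(torch.cuda.current_stream().cuda_stream, task)
# cpu_infer.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)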
@@ -33,6 +33,7 @@ from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
from ktransformers.operators.linear import QuantizedLinearMarlin, QuantizedLinearTorch, KTransformerLinear
import time
from ktransformers.operators.cpuinfer import CPUInfer
# class Base(BaseInjectedModule, ABC):
@@ -117,7 +118,7 @@ class MLPCPUExperts(MLPExpertsBase):
    output_cpu:Tensor = None
    output_gpu_map:dict = {} # Manage output tensor buffer on different gpu
    #stream_map:dict = {} # Manage cuda stream on different gpu
-    CPU_INFER = cpuinfer_ext.CPUInfer(Config().cpu_infer)
    CPU_INFER = CPUInfer(Config().cpu_infer)
    def __init__(
        self,
        key: str,
@@ -126,7 +127,7 @@ class MLPCPUExperts(MLPExpertsBase):
        n_routed_experts: int,
        orig_module: nn.Module = None,
        device: str = "cpu",
-        out_device: str = "cuda", # this device mean which device the output should on
        out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu.
        **kwargs
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
@@ -135,7 +136,6 @@ class MLPCPUExperts(MLPExpertsBase):
        self.out_device = out_device
    def load(self, w: dict | nn.Parameter | tuple | None = None, device:str|None = None, warmup:bool = False):
-        with torch.device(self.out_device):
        if device:
            assert device.lower() == "cpu", "MLPCPUExperts can only be loaded on CPU, Parameter \"device\" can be cpu or None."
        if w is None: w = self.load_weights()[self.key]
......
@@ -11,8 +11,9 @@ Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
'''
import ctypes
import torch
-from torch import nn
from torch import Tensor, nn
import KTransformersOps
from ktransformers.util.custom_gguf import GGUFLoader
from ktransformers.util.utils import InferenceState
@@ -25,7 +26,13 @@ from ktransformers.ktransformers_ext.operators.custom_marlin.quantize.utils.marl
from ktransformers.operators.base_operator import BaseInjectedModule
from transformers.configuration_utils import PretrainedConfig
from abc import ABC, abstractmethod
import sys, os
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Release"))
sys.path.append(os.path.join(os.path.dirname(__file__), "..", "ktransformers_ext", "build", "Debug"))
import cpuinfer_ext
from ktransformers.operators.cpuinfer import CPUInfer
from ktransformers.server.config.config import Config
#class QuantizedLinearBase(BaseInjectedModule, ABC):
class QuantizedLinearBase(ABC):
@@ -118,6 +125,7 @@ class QuantizedLinearTorch(QuantizedLinearBase):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        dtype = x.dtype
        out_device = x.device
        # TODO: support CUDA Graph when using cpu, but CPUInfer is recommended.
        x = x.to(device=self.device, dtype=self.dtype)
        x = x @ self.w
        if self.has_bias:
@@ -244,9 +252,112 @@ class QuantizedLinearMarlin(QuantizedLinearBase):
        self.sort_indices = None
        self.workspace = None
class QuantizedLinearCPUInfer(QuantizedLinearBase):
    CPU_INFER = CPUInfer(Config().cpu_infer)
    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module = None,
        device: str = "cpu",
        out_device: str = "cuda", # this device mean which device the output should on. TODO: support cpu.
        stride = 16,
        group_max_len = 1024,
        **kwargs,
    ):
        super().__init__(key, gguf_loader, config, orig_module, device, **kwargs)
        self.has_bias = False
        self.dtype = torch.get_default_dtype()
        self.w = None
        self.has_bias = False
        self.stride = stride
        self.group_max_len = group_max_len
        self.out_device = out_device

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        origin_shape = x.shape # [batch_size, q_len, hidden_size]
        if origin_shape[1] == 1:
            # Decode path (q_len == 1): reuse the pre-allocated pinned buffers and
            # order the CPU kernel against the current CUDA stream.
            out_device = x.device
            self.input_tensor_cpu.copy_(x, non_blocking=True)
            qlen = origin_shape[1]
            QuantizedLinearCPUInfer.CPU_INFER.submit_with_cuda_stream(
                torch.cuda.current_stream().cuda_stream,
                self.linear.forward(
                    qlen,
                    self.input_tensor_cpu.data_ptr(),
                    self.output_cpu.data_ptr()
                )
            )
            QuantizedLinearCPUInfer.CPU_INFER.sync_with_cuda_stream(torch.cuda.current_stream().cuda_stream)
            self.output_gpu.copy_(self.output_cpu, non_blocking=True)
            if self.has_bias:
                self.output_gpu += self.bias
            return self.output_gpu
        else:
            # Prefill path: run the CPU kernel synchronously into a freshly allocated
            # output tensor, then move the result back to the caller's device.
            dtype = x.dtype
            out_device = x.device
            x = x.to(device=self.device)
            qlen = origin_shape[1]
            output_shape = (*origin_shape[:-1], self.out_features)
            output = torch.empty(output_shape, device=x.device, dtype=x.dtype)
            QuantizedLinearCPUInfer.CPU_INFER.submit(
                self.linear.forward(
                    qlen,
                    x.data_ptr(),
                    output.data_ptr()
                )
            )
            QuantizedLinearCPUInfer.CPU_INFER.sync()
            if self.has_bias:
                output = output + self.bias
            output = output.to(dtype=dtype, device=out_device)
            return output

    def load(self, w: dict | nn.Parameter | tuple | None = None, device: str|None = None, warmup:bool = True):
        print(f"loading {self.key} to {self.device} using CPUInfer")
        if device is None: device = self.device
        self.load_weights(w=w, device=device)
        if self.bias is not None:
            self.has_bias = True
            self.bias = self.bias.to(device)
        # Raw data pointer of the mmapped GGUF weight tensor, handed to the C++ kernel.
        weight_ptr = ctypes.addressof(
            ctypes.cast(self.weight.ctypes.data, ctypes.POINTER(ctypes.c_uint64)).contents
        )
        config = cpuinfer_ext.linear.LinearConfig(self.in_features, self.out_features, self.stride, self.group_max_len, weight_ptr, self.weight_type, 30)
        self.linear = cpuinfer_ext.linear.Linear(config)
        if warmup:
            QuantizedLinearCPUInfer.CPU_INFER.submit(self.linear.warm_up())
            QuantizedLinearCPUInfer.CPU_INFER.sync()
        # Pre-allocated, pinned buffers used by the single-token decode path.
        self.input_tensor_cpu = torch.zeros((1, 1, self.in_features), device="cpu", pin_memory=True)
        self.output_cpu = torch.zeros((1, 1, self.out_features), device="cpu", pin_memory=True, dtype=torch.bfloat16)
        self.output_gpu = torch.zeros((1, 1, self.out_features), device=self.out_device)

    def load_weights(self, w: dict | nn.Parameter | tuple | None = None, device: str = "cpu"):
        if self.key + ".weight" in self.gguf_loader.tensor_info:
            if self.key + ".bias" in self.gguf_loader.tensor_file_map:
                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
                self.bias = self.gguf_loader.load_gguf_tensor(self.key + ".bias", device=device)
            else:
                self.weight = self.gguf_loader.get_mmap_tensor(self.key + ".weight")
                self.weight_type = self.gguf_loader.tensor_info[self.key + ".weight"]["ggml_type"]
                self.bias = None
        else:
            raise ValueError(f"Linear {self.key} not found in gguf_loader")

    def unload(self):
        if self.w is not None:
            self.w = None
        if self.has_bias:
            self.bias = None
LINEAR_MAP = {
    "QuantizedLinearMarlin": QuantizedLinearMarlin,
    "QuantizedLinearTorch": QuantizedLinearTorch,
    "QuantizedLinearCPUInfer": QuantizedLinearCPUInfer
}

class KTransformerLinear(BaseInjectedModule, QuantizedLinearBase):
......
@@ -58,7 +58,6 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
    #print("gen_optimize_config", prefix, module_name, translated_name)
    recursive = True
    for rule in rule_list:
-        #print(rule)
        match_meta = rule["match"]
        if "class" not in match_meta and "name" not in match_meta:
            raise Exception("match must have at least one of \"class\" and \"name\"")
@@ -87,6 +86,7 @@ def gen_optimize_config(module: nn.Module, out_data: Mapping, rule_list: List, p
                out_data[module_name]["kwargs"].update(copy.deepcopy(replace_meta["kwargs"]) if "kwargs" in replace_meta else dict())
        if "recursive" in rule:
            recursive = bool(rule["recursive"])
        break
    if module_name not in out_data:
        out_data[module_name]= {
@@ -127,5 +127,6 @@ def optimize_and_load_gguf(module: nn.Module, rule_file: str, gguf_path: str, mo
    with torch.device("meta"):
        inject(module, optimize_config, model_config, gguf_loader)
    load_weights(module, gguf_loader)
-    model_config.gguf_loader = gguf_loader
    module.gguf_loader = gguf_loader
    del_meta(module)
    torch.cuda.empty_cache()
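The added break makes gen_optimize_config stop at the first rule whose match block fits a module, so earlier rules now take precedence over later ones; accordingly, the YAML files below move the generic "default" device-placement rules to the end so they only act as a fallback. A minimal sketch of the resulting ordering convention (the regexes and devices here are illustrative, not copied from a shipped config):

- match:
    name: "^model\\.layers\\..*$"
    class: torch.nn.Linear
  replace:
    class: ktransformers.operators.linear.KTransformerLinear  # specific rule, listed first, wins for matching Linear modules
- match:
    name: "^model\\.layers\\..*\\."
  replace:
    class: "default"  # catch-all fallback, listed last
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"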
- match:
    name: "^model\\.layers\\.([0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([1][0-9])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "(^model\\.layers\\.([2][0-9])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
- match:
    name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
- match:
    name: "^model.embed_tokens"
  replace:
@@ -69,7 +40,7 @@
      prefill_device: "cuda:3"
- match:
-    name: "^model\\.layers\\.([1][0-9])\\.(?!self_attn).*$" # regular expression
    name: "^model\\.layers\\.([0-9])\\.(?!self_attn).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
  replace:
    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
@@ -226,3 +197,32 @@
      10: "cuda:1"
      20: "cuda:2"
      30: "cuda:3"
- match:
    name: "^model\\.layers\\.([0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([1][0-9])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "(^model\\.layers\\.([2][0-9])\\.)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:2"
      prefill_device: "cuda:2"
- match:
    name: "(^model\\.layers\\.([345][0-9])\\.)|(^model.norm)|(^lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:3"
      prefill_device: "cuda:3"
\ No newline at end of file
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model.embed_tokens"
  replace:
@@ -124,3 +108,19 @@
    per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
    transfer_map:
      30: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([345][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
\ No newline at end of file
- match:
-    name: "^model\\.layers\\..*\\.|^lm_head"
    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
  replace:
-    class: "default"
    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
-- match:
-    class: ktransformers.models.modeling_deepseek.DeepseekV2YarnRotaryEmbedding
-  replace:
-    class: ktransformers.operators.RoPE.YarnRotaryEmbedding
#- match:
#    name: "^model\\.layers\\.([1-5][0-9])\\.mlp\\.shared_experts.*$" # regular expression
#    class: torch.nn.Linear # only match modules matching name and class simultaneously
#  replace:
#    class: ktransformers.operators.linear.KTransformerLinear # optimized Kernel on quantized data types
#    kwargs:
#      generate_device: "cpu"
#      prefill_device: "cuda"
#      generate_op: "QuantizedLinearCPUInfer"
#      prefill_op: "QuantizedLinearTorch"
#      out_device: "cuda"
- match:
    name: "^model\\.layers\\.(?!.*self_attn).*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -24,6 +31,9 @@
    class: ktransformers.models.modeling_deepseek.DeepseekV2MoE
  replace:
    class: ktransformers.operators.experts.DeepseekV2MoEInjected # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
@@ -39,11 +49,16 @@
    name: "^model\\.layers\\..*\\.self_attn$"
  replace:
    class: ktransformers.operators.attention.DeepseekV2AttentionInjected # optimized MLA implementation
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model$"
  replace:
    class: "ktransformers.operators.layer_wise_prefill.DeepseekV2ModelKTransformers"
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
      per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
- match:
    name: "^model.embed_tokens"
......
- match:
    name: "^model\\.layers\\.(0|[1-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model.embed_tokens"
  replace:
@@ -124,3 +108,19 @@
    per_layer_prefill_intput_threshold: 0 # 0 is close layer wise prefill
    transfer_map:
      10: "cuda:1"
- match:
    name: "^model\\.layers\\.(0|[1-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "(^model\\.layers\\.([12][0-9])\\.)|(model.norm)|(lm_head)"
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
\ No newline at end of file
- match:
-    name: "^model\\.layers\\..*\\."
    class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding
  replace:
-    class: "default"
    class: ktransformers.operators.RoPE.RotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_mixtral.MixtralRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbedding
- match:
    name: "^model\\.layers\\..*$"
    class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -43,3 +39,11 @@
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\..*\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
\ No newline at end of file
- match:
    name: "^model\\.layers\\.([012])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([012])\\."
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
@@ -41,13 +34,6 @@
out_device: "cuda:0" out_device: "cuda:0"
recursive: False # don't recursively inject submodules of this module recursive: False # don't recursively inject submodules of this module
- match:
    name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
- match:
    name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
@@ -109,3 +95,18 @@
    transfer_map:
      3: "cuda:1"
- match:
    name: "^model\\.layers\\.([012])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:0"
      prefill_device: "cuda:0"
- match:
    name: "^model\\.layers\\.([12][0-9]|[3-9])\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda:1"
      prefill_device: "cuda:1"
\ No newline at end of file
- match:
-    name: "^model\\.layers\\..*\\."
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
  replace:
-    class: "default"
    class: ktransformers.operators.RoPE.RotaryEmbedding
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeRotaryEmbedding
  replace:
    class: ktransformers.operators.RoPE.RotaryEmbedding
- match:
    name: "^model\\.layers\\..*$" # regular expression
    class: torch.nn.Linear # only match modules matching name and class simultaneously
@@ -24,6 +20,9 @@
    class: ktransformers.models.modeling_qwen2_moe.Qwen2MoeSparseMoeBlock
  replace:
    class: ktransformers.operators.experts.Qwen2MoeSparseMoeBlockInjected # mlp module with custom forward function
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
- match:
    name: "^model\\.layers\\..*\\.mlp\\.experts$"
  replace:
@@ -49,3 +48,10 @@
    kwargs:
      generate_device: "cpu"
      prefill_device: "cpu"
- match:
    name: "^model\\.layers\\..*\\."
  replace:
    class: "default"
    kwargs:
      generate_device: "cuda"
      prefill_device: "cuda"
\ No newline at end of file
@@ -6,6 +6,7 @@ from ktransformers.optimize.optimize import optimize_and_load_gguf
from ktransformers.models.custom_cache import StaticCache
from ktransformers.util.cuda_graph_runner import CUDAGraphRunner
from ktransformers.local_chat import custom_models, default_optimize_rules
from ktransformers.util.utils import get_device

class KTransformersThreadContext(TransformersThreadContext):
@@ -48,8 +49,11 @@ class KTransformersInterface(TransformersInterface):
    def decode_one_tokens(self):
        if not hasattr(self, "cuda_graph_runner"):
            device_map = self.model.gguf_loader.tensor_device_map
            torch_device = get_device('blk.0.self_attn', device_map)
            torch_device = "cuda:0" if torch_device == "cuda" else torch_device
            self.cuda_graph_runner = CUDAGraphRunner()
-            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, return_dict=False, use_cache=True)
            self.cuda_graph_runner.capture(self.model, self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position, self.cache, main_device=torch_device, return_dict=False, use_cache=True)
        if hasattr(self, "cuda_graph_runner"):
            logits = self.cuda_graph_runner(self.current_ids, self.active_cache_position.unsqueeze(0), self.active_cache_position)
......
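Both decode_one_tokens above and prefill_and_generate below now resolve the capture device from the loader's placement table rather than from model.config. A hedged sketch of the shared pattern (the device_map contents described are invented for illustration):

device_map = self.model.gguf_loader.tensor_device_map  # attached in optimize_and_load_gguf via module.gguf_loader = gguf_loader
# device_map may map keys such as "blk.0.self_attn" to a concrete GPU in a multi-GPU layout
torch_device = get_device('blk.0.self_attn', device_map)  # device hosting the first attention block
torch_device = "cuda:0" if torch_device == "cuda" else torch_device  # normalize a bare "cuda" to an explicit index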
@@ -89,7 +89,7 @@ def prefill_and_generate(model, tokenizer, inputs, max_new_tokens=10000, use_cud
    os.environ["TOKENIZERS_PARALLELISM"] = "false"
    torch._dynamo.config.suppress_errors = True
    batch_size, seq_length = inputs.shape
-    device_map = model.config.gguf_loader.tensor_device_map
    device_map = model.gguf_loader.tensor_device_map
    torch_device = get_device('blk.0.self_attn', device_map)
    torch_device = "cuda:0" if torch_device == "cuda" else torch_device
    inputs = inputs.to(torch_device)
......