Commit 907251c7 authored by Azure

done: support DeepSeek-V3

parent f748cd29
@@ -23,7 +23,7 @@ from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
 from ktransformers.util.utils import InferenceState
 from transformers.configuration_utils import PretrainedConfig
+import torch
 
 # Copied from transformers.models.mixtral.modeling_mixtral.MixtralRotaryEmbedding with Mixtral->Qwen2Moe
 class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
@@ -56,6 +56,57 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
         )
 
+class RotaryEmbeddingV3(BaseInjectedModule):
+    def __init__(
+        self,
+        key: str,
+        gguf_loader: GGUFLoader,
+        config: PretrainedConfig,
+        orig_module: nn.Module,
+        # device: str = "cuda",
+        generate_device: str = "cuda",
+        prefill_device: str = "cuda",
+        **kwargs,
+    ):
+        BaseInjectedModule.__init__(
+            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+        )
+        self.generate_device = generate_device
+        self.prefill_device = prefill_device
+
+    @torch.no_grad()
+    def forward(self, x, position_ids):
+        # x: [bs, num_attention_heads, seq_len, head_size]
+        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
+        position_ids_expanded = position_ids[:, None, :].float()
+        # Force float32 since bfloat16 loses precision on long contexts
+        # See https://github.com/huggingface/transformers/pull/29285
+        device_type = x.device.type
+        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
+        with torch.autocast(device_type=device_type, enabled=False):
+            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos()
+            sin = emb.sin()
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+    def load(self):
+        self._init(
+            dim=self.config.qk_rope_head_dim,
+            max_position_embeddings=self.config.max_position_embeddings,
+            base=self.config.rope_theta,
+            device=self.device,
+        )
+
+    def _init(self, dim, max_position_embeddings, base, device, scaling_factor=1.0):
+        self.scaling_factor = scaling_factor
+        self.dim = dim
+        self.max_position_embeddings = max_position_embeddings
+        self.base = base
+        self.inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
+        # self.register_buffer("inv_freq", inv_freq, persistent=False)
+        # For BC we register cos and sin cached
+        self.max_seq_len_cached = max_position_embeddings
+
 class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
     def __init__(
         self,
...
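The new `RotaryEmbeddingV3` drops the YaRN scaling path and computes plain rotary cos/sin tables directly from `inv_freq`. As a rough standalone sketch of the same math, with hypothetical sizes and outside the injection framework:

```python
import torch

# Plain RoPE frequency table, as in RotaryEmbeddingV3._init:
# inv_freq[j] = 1 / base^(2j / dim)
dim, base = 64, 10000.0
inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

# forward(): outer product of positions and inverse frequencies,
# duplicated so cos/sin cover the full head dimension.
position_ids = torch.arange(8)[None, :]                              # [bs=1, seq_len=8]
inv_freq_expanded = inv_freq[None, :, None].float()                  # [1, dim/2, 1]
position_ids_expanded = position_ids[:, None, :].float()             # [1, 1, seq_len]
freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # [1, seq_len, dim/2]
emb = torch.cat((freqs, freqs), dim=-1)                              # [1, seq_len, dim]
cos, sin = emb.cos(), emb.sin()
print(cos.shape)  # torch.Size([1, 8, 64])
```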
@@ -151,7 +151,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
         attn_output = self.o_proj(attn_output)
 
-        return attn_output, attn_weights
+        return attn_output, attn_weights, past_key_value
 
     def forward(
         self,
@@ -220,7 +220,7 @@ class KDeepseekV3Attention(BaseInjectedModule, DeepseekV3Attention):
             attn_output = torch.cat((attn_output, cur_output), dim=-2)
             attn_weight = torch.cat((attn_weight, cur_attn_weight), dim=-2)
 
-        return attn_output, attn_weight
+        return attn_output, attn_weight, past_key_value
 
 class KDeepseekV2Attention(BaseInjectedModule, DeepseekV2Attention):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
...
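Both attention paths now return `past_key_value` as a third element, which lets the model loop recover the cache from `layer_outputs` (see the re-enabled `use_cache` hunk further down). A minimal sketch of the caller-side convention, with a hypothetical stand-in for the decoder layer:

```python
import torch

# Hypothetical stand-in: attention returns (attn_output, attn_weights,
# past_key_value), and the layer only includes attn_weights in its
# outputs when output_attentions is set.
def decoder_layer(hidden_states, output_attentions=False, use_cache=True, past_key_value=None):
    attn_output, attn_weights = hidden_states, None  # placeholder computation
    outputs = (attn_output,)
    if output_attentions:
        outputs += (attn_weights,)
    if use_cache:
        outputs += (past_key_value,)
    return outputs

layer_outputs = decoder_layer(torch.randn(1, 4, 8), output_attentions=False)

# Matches the indexing in the model loop below:
output_attentions, use_cache = False, True
if use_cache:
    next_decoder_cache = layer_outputs[2 if output_attentions else 1]
```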
@@ -734,7 +734,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         identity = hidden_states
         orig_shape = hidden_states.shape
         sequence_length = orig_shape[1]
-        topk_idx, topk_weight, router_logits = self.gate(hidden_states)
+        topk_idx, topk_weight = self.gate(hidden_states)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
 
         # only for generate phase
@@ -745,7 +745,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
             y = self.experts.generate_experts.sync_for_one_decode().unsqueeze(0)
             y += y_
             y.resize_(*orig_shape)
-            return y, router_logits
+            return y
 
         if self.config.n_shared_experts is not None:
             y_ = self.shared_experts(identity).squeeze(0)
@@ -768,7 +768,7 @@ class KDeepseekV3MoE(BaseInjectedModule, DeepseekV3MoE):
         )
         if self.config.n_shared_experts is not None:
             y += y_
-        return y, router_logits
+        return y
 
     @torch.no_grad()
     def moe_on_cpuinfer(self, x: torch.Tensor, topk_ids: torch.Tensor, topk_weight: torch.Tensor) -> torch.Tensor:
...
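`self.gate` is now expected to return only `(topk_idx, topk_weight)`; `router_logits` is no longer threaded through the MoE forward or its return values. A minimal sketch of a gate with that contract (hypothetical shapes, plain softmax routing rather than DeepSeek's exact scoring):

```python
import torch

def gate(hidden_states: torch.Tensor, n_experts: int = 8, top_k: int = 2):
    """Return (topk_idx, topk_weight) only, as KDeepseekV3MoE now expects."""
    # Hypothetical router weight; the real gate loads trained weights.
    w = torch.randn(hidden_states.shape[-1], n_experts)
    scores = torch.softmax(hidden_states @ w, dim=-1)          # [tokens, n_experts]
    topk_weight, topk_idx = torch.topk(scores, top_k, dim=-1)  # values, indices
    return topk_idx, topk_weight

x = torch.randn(4, 16)           # [tokens, hidden]
topk_idx, topk_weight = gate(x)  # no router_logits in the tuple
```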
@@ -16,9 +16,6 @@ from cpuinfer_ext.moe import MOEConfig, MOE
 import ctypes
 from ktransformers.operators.base_operator import BaseInjectedModule
 from ktransformers.util.custom_gguf import GGUFLoader
-from ktransformers.models.modeling_deepseek_v3 import DeepseekV3TopkRouter
-from ktransformers.util.utils import InferenceState
-from ktransformers.server.config.config import Config
 from transformers.activations import ACT2FN
 from transformers.configuration_utils import PretrainedConfig
 from abc import ABC, abstractmethod
@@ -102,6 +99,8 @@ class KMoEGate(BaseInjectedModule, KMoEGateBase):
     ):
         BaseInjectedModule.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
         KMoEGateBase.__init__(self, key, gguf_loader, config, orig_module, generate_device, **kwargs)
+        self.generate_device = generate_device
+        self.prefill_device = prefill_device
 
     def forward(self, hidden_states) -> torch.Tensor:
         return self.orig_module.forward(hidden_states)
...
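`KMoEGate` now records both per-phase device names while its forward still delegates to the wrapped module. A simplified sketch of that wrapper pattern (hypothetical names, plain `nn.Module` base instead of the injection machinery):

```python
import torch
import torch.nn as nn

# Sketch of the injection pattern: keep the original module plus
# per-phase device names, and forward calls unchanged.
class InjectedGate(nn.Module):
    def __init__(self, orig_module: nn.Module,
                 generate_device: str = "cuda", prefill_device: str = "cuda"):
        super().__init__()
        self.orig_module = orig_module
        self.generate_device = generate_device
        self.prefill_device = prefill_device

    def forward(self, hidden_states: torch.Tensor):
        return self.orig_module(hidden_states)
```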
@@ -625,6 +625,13 @@ class KDeepseekV2Model(BaseInjectedModule):
             if use_legacy_cache:
                 past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             past_key_values_length = past_key_values.get_usable_length(seq_length)
 
+        if inputs_embeds is None:
+            org_device = input_ids.device
+            # TODO: move to embed_tokens's device instead of hard-coding CPU
+            input_ids = input_ids.to("cpu")
+            inputs_embeds = self.embed_tokens(input_ids).to(org_device)
+            input_ids = input_ids.to(org_device)
+
         if cache_position is None:
             past_seen_tokens = (
@@ -639,13 +646,6 @@ class KDeepseekV2Model(BaseInjectedModule):
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
 
-        if inputs_embeds is None:
-            org_device = input_ids.device
-            # TODO: move to embed_tokens's device instead of hard-coding CPU
-            input_ids = input_ids.to("cpu")
-            inputs_embeds = self.embed_tokens(input_ids).to(org_device)
-            input_ids = input_ids.to(org_device)
-
         if per_layer_prefill_flag:
             causal_mask = None
         else:
@@ -717,6 +717,8 @@ class KDeepseekV2Model(BaseInjectedModule):
                     self.load_layer_to(decoder_layer, InferenceState.PREFILL)
                     torch.cuda.empty_cache()
                     t4 = time.time()
+                # with open("log.txt", "a") as f:
+                #     f.write(f"@@@@@@@@@@@@@@@@@layer {i}@@@@@@@@@@@@@@@@@@@@ \n")
                 layer_outputs = decoder_layer(
                     hidden_states,
                     attention_mask=causal_mask,
@@ -739,13 +741,17 @@ class KDeepseekV2Model(BaseInjectedModule):
             hidden_states = layer_outputs[0]
 
-            # @@@@@@@ TODO open this notes, tmp close to fit deepseekv3
-            # if use_cache:
-            #     next_decoder_cache = layer_outputs[2 if output_attentions else 1]
+            if use_cache:
+                next_decoder_cache = layer_outputs[2 if output_attentions else 1]
 
             if output_attentions:
                 all_self_attns += (layer_outputs[1],)
 
         hidden_states = self.norm(hidden_states)
+        # with open("log.txt", "a") as f:
+        #     f.write(f"@@@After layers\n")
+        #     f.write(f"hidden_states={hidden_states}\n")
+        #     f.write(f"hidden_states.shape={hidden_states.shape}\n")
 
         if per_layer_prefill_flag:
             t6 = time.time()
...
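The embedding lookup is hoisted above the `cache_position` logic and, per the TODO, still hard-codes CPU: `input_ids` are moved to CPU for `embed_tokens` and the resulting embeddings are moved back to the original device. A standalone sketch of that round-trip, assuming the embedding table is deliberately kept on CPU (e.g. to save VRAM):

```python
import torch
import torch.nn as nn

# Embedding weights pinned on CPU; activations live on the GPU.
embed_tokens = nn.Embedding(32000, 128)  # stays on CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

input_ids = torch.randint(0, 32000, (1, 16), device=device)
org_device = input_ids.device
inputs_embeds = embed_tokens(input_ids.to("cpu")).to(org_device)
input_ids = input_ids.to(org_device)  # restore original placement
```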
@@ -10,7 +10,7 @@
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\."
     class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
   replace:
-    class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
   kwargs:
     generate_device: "cuda:0"
     prefill_device: "cuda:0"
@@ -18,7 +18,7 @@
     name: "^model\\.layers\\.([3456][0-9])\\."
     class: ktransformers.models.modeling_deepseek_v3.DeepseekV3RotaryEmbedding
   replace:
-    class: ktransformers.operators.RoPE.DeepSeekV3YarnRotaryEmbedding
+    class: ktransformers.operators.RoPE.RotaryEmbeddingV3
   kwargs:
     generate_device: "cuda:1"
     prefill_device: "cuda:1"
@@ -64,7 +64,7 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
     class: ktransformers.operators.gate.KMoEGate
   kwargs:
@@ -72,7 +72,7 @@
     prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.mlp\\.gate$"
-    class: ktransformers.models.modeling_deepseek_v3.DeepseekV3TopkRouter
+    class: ktransformers.models.modeling_deepseek_v3.MoEGate
   replace:
     class: ktransformers.operators.gate.KMoEGate # mlp module with custom forward function
   kwargs:
@@ -106,14 +106,14 @@
 - match:
     name: "^model\\.layers\\.(0|[1-9]|[12][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
   kwargs:
     generate_device: "cuda:0"
     prefill_device: "cuda:0"
 - match:
     name: "^model\\.layers\\.([3456][0-9])\\.self_attn$"
   replace:
-    class: ktransformers.operators.attention.KDeepseekV3Attention # optimized MLA implementation
+    class: ktransformers.operators.attention.KDeepseekV2Attention # optimized MLA implementation
   kwargs:
     generate_device: "cuda:1"
     prefill_device: "cuda:1"
...
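The rule file splits the model across GPUs purely via the `name` regexes: layers 0-29 go to `cuda:0` and layers 30-69 to `cuda:1`. A quick check of what those patterns match, in plain Python outside the injection framework:

```python
import re

low  = re.compile(r"^model\.layers\.(0|[1-9]|[12][0-9])\.")  # layers 0-29 -> cuda:0
high = re.compile(r"^model\.layers\.([3456][0-9])\.")        # layers 30-69 -> cuda:1

for i in (0, 29, 30, 69):
    name = f"model.layers.{i}.self_attn"
    print(i, bool(low.match(name)), bool(high.match(name)))
# 0  True  False
# 29 True  False
# 30 False True
# 69 False True
```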
@@ -24,7 +24,7 @@ class KTransformersInterface(TransformersInterface):
         self.args = args
         torch.set_default_dtype(torch.bfloat16)
         torch.set_grad_enabled(False)
-        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device)
+        self.tokenizer = AutoTokenizer.from_pretrained(args.model_dir, device=args.device, trust_remote_code=True)
         config = AutoConfig.from_pretrained(args.model_dir, trust_remote_code=True)
         if config.architectures[0] == "Qwen2MoeForCausalLM":
             config._attn_implementation = "flash_attention_2"
@@ -99,7 +99,7 @@ class KTransformersInterface(TransformersInterface):
         if self.use_static_cache:
             mask = torch.ones((1, self.seq_length)).to(torch_device)
             logits = self.model(
-                self.current_ids,
+                self.current_ids.to(torch_device),
                 cache_position=self.active_cache_position,
                 past_key_values=self.cache,
                 attention_mask=mask,
...
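Two small fixes here: the tokenizer gains `trust_remote_code=True` (needed for checkpoints that ship custom tokenizer code, matching the existing `AutoConfig` call), and `current_ids` is moved to the target device before the forward call so the input ids sit on the same device as the mask and cache. A minimal illustration of the mismatch this avoids:

```python
import torch

device = "cuda:0" if torch.cuda.is_available() else "cpu"
current_ids = torch.tensor([[42]])        # created on CPU by default
mask = torch.ones((1, 8)).to(device)

# Passing current_ids as-is would hand a CPU tensor to a model whose
# buffers live on `device`; moving it first keeps everything aligned.
current_ids = current_ids.to(device)
```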
@@ -198,7 +198,7 @@ class TransformersInterface(BackendInterfaceBase):
         return self.streamer.put(new_tokens)
 
     def logits_to_token(self, logits: torch.Tensor):
-        logits = logits / self.args.temperature
+        logits = logits / self.args.temperature if self.args.temperature != 0 else logits
 
         for token_idx in self.ever_generated_ids:
             if logits[token_idx] < 0:
@@ -318,7 +318,9 @@ class TransformersInterface(BackendInterfaceBase):
         if isinstance(local_messages, List):
            input_ids = self.format_and_tokenize_input_ids(thread_id, local_messages)
         elif isinstance(local_messages, str):
+            # local_messages = local_messages[0]['content']
             input_ids = self.tokenize_prompt(local_messages)
+            # input_ids = torch.tensor([[6366]], device=input_ids.device)
         else:
             raise ValueError("local_messages should be List or str")
...
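The temperature division now guards against `temperature == 0`, which would otherwise divide by zero and produce infinite logits; zero is conventionally used to request greedy decoding. A small sketch of sampling under that convention (hypothetical helper, argmax for the zero case):

```python
import torch

def logits_to_token(logits: torch.Tensor, temperature: float) -> int:
    # temperature == 0 is treated as "greedy": skip scaling, take argmax.
    if temperature == 0:
        return int(logits.argmax())
    probs = torch.softmax(logits / temperature, dim=-1)
    return int(torch.multinomial(probs, 1))

logits = torch.randn(32000)
print(logits_to_token(logits, 0.0))  # deterministic
print(logits_to_token(logits, 0.7))  # sampled
```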