Commit c009512a authored by Azure-Tang

Merge branch 'main' into hip

parents c1f13a69 4f22d726
@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
         )
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits[:,-1,:].unsqueeze(0).float()
+        logits = self.lm_head(hidden_states[:,-1:,:]).float()
         loss = None
         if labels is not None:
......
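Note on the change above: the final-position slice is folded into the `lm_head` call, so logits are computed for the last token only rather than for every position. A minimal, self-contained sketch of why the two forms agree (sizes here are illustrative, not the model's):

```python
import torch
import torch.nn as nn

batch, seq, hidden, vocab = 1, 128, 512, 32000  # illustrative sizes only

lm_head = nn.Linear(hidden, vocab, bias=False)
hidden_states = torch.randn(batch, seq, hidden)

# before: project every position, then keep the last one
logits_before = lm_head(hidden_states)[:, -1:, :].float()
# after: slice first, so the projection runs on a single position
logits_after = lm_head(hidden_states[:, -1:, :]).float()

assert torch.allclose(logits_before, logits_after, atol=1e-5)
```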
@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         )
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states.to(self.lm_head.weight.device))
+        logits = self.lm_head(hidden_states[:,-1:,:])
         logits = logits.float()
         loss = None
......
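The V3 head gets the same last-token treatment. Back-of-envelope arithmetic on what the slice saves (assuming DeepSeek-V3's 129280-entry vocabulary, an 8K-token prompt, and fp32 logits):

```python
vocab, seq, bytes_fp32 = 129280, 8192, 4  # vocab size assumed from the V3 config

full_mib = seq * vocab * bytes_fp32 / 2**20   # logits for every position
last_mib = 1 * vocab * bytes_fp32 / 2**20     # logits for the last position only
print(f"full: {full_mib:.0f} MiB, last-only: {last_mib:.2f} MiB")
# full: 4040 MiB, last-only: 0.49 MiB
```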
@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim, orig_module.max_position_embeddings, orig_module.base
@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
 #         **kwargs,
 #     ):
 #         BaseInjectedModule.__init__(
-#             self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+#             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
 #         )
 #         self.generate_device = generate_device
 #         self.prefill_device = prefill_device
@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
......
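Every rotary variant above now accepts `prefill_device` alongside `generate_device` and forwards both to `BaseInjectedModule`, replacing the old single `device`/`generate_device` argument. A sketch of the phase-dependent dispatch this enables (assumed usage; `run_with_phase` and `PhasedLinear` are illustrative, not code from this commit):

```python
import torch
import torch.nn as nn

def run_with_phase(module: nn.Module, x: torch.Tensor, is_prefill: bool) -> torch.Tensor:
    # pick the device recorded for the current phase; attribute names
    # follow the diff (prefill_device / generate_device)
    device = module.prefill_device if is_prefill else module.generate_device
    return module.to(device)(x.to(device))

class PhasedLinear(nn.Linear):
    def __init__(self, *args, prefill_device="cpu", generate_device="cpu", **kwargs):
        super().__init__(*args, **kwargs)
        self.prefill_device = prefill_device
        self.generate_device = generate_device

m = PhasedLinear(8, 8)
y = run_with_phase(m, torch.randn(2, 8), is_prefill=True)
```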
@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
                  gguf_loader : GGUFLoader,
                  config: PretrainedConfig,
                  orig_module: nn.Module,
-                 device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
                  **kwargs):
         nn.Module.__init__(self)
         nn.Module.__setattr__(self, "orig_module", orig_module)
         object.__setattr__(self, "key", key)
         object.__setattr__(self, "gguf_loader", gguf_loader)
         object.__setattr__(self, "config", config)
-        object.__setattr__(self, "device", device)
+        object.__setattr__(self, "prefill_device", prefill_device)
+        object.__setattr__(self, "generate_device", generate_device)
+        object.__setattr__(self, "device", generate_device)
 
     def __getattr__(self, name: str) -> Any:
         # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
......
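In the base class, `device` is kept but now aliases `generate_device`, so call sites that still read `self.device` keep working while phase-aware code uses the two new attributes. One guess at how a mode switch could consume them (`set_inference_mode` below is hypothetical; `InferenceState` stands in for the real class in `ktransformers.util.utils`):

```python
from enum import Enum
import torch.nn as nn

class InferenceState(Enum):  # stand-in for ktransformers.util.utils.InferenceState
    PREFILL = "prefill"
    GENERATE = "generate"

def set_inference_mode(module, state: InferenceState) -> None:
    # hypothetical helper: migrate the wrapped module to the device
    # recorded for the requested phase
    target = module.prefill_device if state is InferenceState.PREFILL else module.generate_device
    module.orig_module.to(target)
```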
@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
+from ktransformers.util.utils import InferenceState, get_compute_capability
 from ktransformers.util.custom_gguf import GGUFLoader
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.models.modeling_llama import (
@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            causal_mask = self._update_causal_mask(
-                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-            )
+            if os.name == 'nt' or get_compute_capability() < 8:
+                print("for Windows or pre-Ampere GPUs, use forward_windows")
+                # only build the causal mask on Windows or when flash attention is unavailable
+                causal_mask = self._update_causal_mask(
+                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+                )
+            else:
+                causal_mask = None
         # embed positions
         hidden_states = inputs_embeds
......
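`get_compute_capability` gates the flash-attention path: Ampere GPUs report compute capability 8.x and flash attention needs SM80+, hence the `< 8` check, with Windows always taking the masked fallback. A plausible shape for the helper (the real body lives in `ktransformers.util.utils`; this sketch, including the CPU fallback, is an assumption):

```python
import torch

def get_compute_capability() -> int:
    # assumption: CPU-only hosts are treated like pre-Ampere hardware
    if not torch.cuda.is_available():
        return 0
    major, _minor = torch.cuda.get_device_capability()
    return major
```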