Commit c009512a authored by Azure-Tang

Merge branch 'main' into hip

parents c1f13a69 4f22d726
@@ -1742,8 +1742,7 @@ class DeepseekV2ForCausalLM(DeepseekV2PreTrainedModel):
         )
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states)
-        logits = logits[:,-1,:].unsqueeze(0).float()
+        logits = self.lm_head(hidden_states[:,-1:,:]).float()
         loss = None
         if labels is not None:
...
@@ -1699,7 +1699,7 @@ class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
         )
         hidden_states = outputs[0]
-        logits = self.lm_head(hidden_states.to(self.lm_head.weight.device))
+        logits = self.lm_head(hidden_states[:,-1:,:])
         logits = logits.float()
         loss = None
...
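Both hunks make the same optimization: the last-position slice now happens before the `lm_head` projection, so the vocabulary matmul runs over a single position instead of the whole sequence (the DeepseekV2 hunk also drops the stray `unsqueeze(0)`, which produced a `(1, batch, vocab)` tensor rather than `(batch, 1, vocab)`). Since a linear layer is applied position-wise, slicing commutes with the projection; a minimal sketch with toy shapes (not the real model) confirming the two orderings agree:

```python
import torch
import torch.nn as nn

# Toy stand-ins: batch=2, seq_len=16, hidden=64, vocab=128.
lm_head = nn.Linear(64, 128, bias=False)
hidden_states = torch.randn(2, 16, 64)

# Old order: project every position, then keep only the last one.
full = lm_head(hidden_states)[:, -1:, :].float()
# New order: slice first, so only one position goes through the projection.
sliced = lm_head(hidden_states[:, -1:, :]).float()

assert torch.allclose(full, sliced, atol=1e-6)
```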
@@ -42,7 +42,7 @@ class RotaryEmbedding(BaseInjectedModule, DeepseekV2RotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim, orig_module.max_position_embeddings, orig_module.base
@@ -72,7 +72,7 @@ class RotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -122,7 +122,7 @@ class RotaryEmbeddingV2(BaseInjectedModule, LlamaRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -160,7 +160,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
@@ -204,7 +204,7 @@ class YarnRotaryEmbedding(BaseInjectedModule, DeepseekV2YarnRotaryEmbedding):
 #         **kwargs,
 #     ):
 #         BaseInjectedModule.__init__(
-#             self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+#             self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
 #         )
 #         self.generate_device = generate_device
 #         self.prefill_device = prefill_device
@@ -230,7 +230,7 @@ class YarnRotaryEmbeddingV3(BaseInjectedModule):
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, generate_device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.generate_device = generate_device
         self.prefill_device = prefill_device
@@ -332,11 +332,12 @@ class DynamicNTKScalingRotaryEmbedding(
         gguf_loader: GGUFLoader,
         config: PretrainedConfig,
         orig_module: nn.Module,
-        device: str = "cuda",
+        prefill_device: str = "cuda",
+        generate_device: str = "cuda",
         **kwargs,
     ):
         BaseInjectedModule.__init__(
-            self, key, gguf_loader, config, orig_module, device, **kwargs
+            self, key, gguf_loader, config, orig_module, prefill_device, generate_device, **kwargs
         )
         self.orig_module.__init__(
             orig_module.dim,
...
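Every wrapper above forwards the devices to `BaseInjectedModule.__init__` positionally, so these call sites stay correct only because the base signature (next hunk) declares `prefill_device` before `generate_device`. A self-contained check of that ordering against a stub mirroring the new signature (the stub is illustrative, not the real class):

```python
import inspect

# Stub mirroring the new BaseInjectedModule.__init__ parameter list.
def base_init(self, key, gguf_loader, config, orig_module,
              prefill_device="cuda", generate_device="cuda", **kwargs):
    pass

params = list(inspect.signature(base_init).parameters)
# Positional call sites pass prefill_device first, then generate_device.
assert params.index("prefill_device") < params.index("generate_device")
```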
@@ -16,14 +16,17 @@ class BaseInjectedModule(nn.Module):
                  gguf_loader : GGUFLoader,
                  config: PretrainedConfig,
                  orig_module: nn.Module,
-                 device: str = "cuda",
+                 prefill_device: str = "cuda",
+                 generate_device: str = "cuda",
                  **kwargs):
         nn.Module.__init__(self)
         nn.Module.__setattr__(self, "orig_module", orig_module)
         object.__setattr__(self, "key", key)
         object.__setattr__(self, "gguf_loader", gguf_loader)
         object.__setattr__(self, "config", config)
-        object.__setattr__(self, "device", device)
+        object.__setattr__(self, "prefill_device", prefill_device)
+        object.__setattr__(self, "generate_device", generate_device)
+        object.__setattr__(self, "device", generate_device)

     def __getattr__(self, name: str) -> Any:
         # __getattr__ in nn.Module doesn't call super().__getattribute__ when name is not in nn.Module.__dict__,
...
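The base class now records both devices and keeps `device` aliased to `generate_device`, so code written against the old single-device attribute keeps working. A minimal sketch of the pattern, assuming (as the surrounding ktransformers code suggests) that an injected module moves its wrapped module between the two devices depending on the inference phase; class and method names here are illustrative, not the real API:

```python
import torch.nn as nn

class DeviceAwareModule(nn.Module):
    """Sketch of the prefill/generate device split, not BaseInjectedModule's
    real internals: weights follow whichever phase is currently running."""

    def __init__(self, orig_module: nn.Module,
                 prefill_device: str = "cuda",
                 generate_device: str = "cuda"):
        super().__init__()
        self.orig_module = orig_module
        self.prefill_device = prefill_device
        self.generate_device = generate_device
        # `device` keeps its old meaning by aliasing the generate device.
        self.device = generate_device

    def set_inference_mode(self, mode: str) -> None:
        # Move the wrapped module to the device for the requested phase.
        target = self.prefill_device if mode == "prefill" else self.generate_device
        self.orig_module.to(target)

# Usage: prefill and decode can live on different devices (CPU here so the
# sketch runs without a GPU).
mod = DeviceAwareModule(nn.Linear(8, 8), prefill_device="cpu", generate_device="cpu")
mod.set_inference_mode("prefill")
```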
@@ -56,7 +56,7 @@ from ktransformers.models.modeling_deepseek import (
 from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
 from ktransformers.models.configuration_llama import LlamaConfig
 from ktransformers.operators.base_operator import BaseInjectedModule
-from ktransformers.util.utils import InferenceState
+from ktransformers.util.utils import InferenceState, get_compute_capability
 from ktransformers.util.custom_gguf import GGUFLoader
 from transformers.configuration_utils import PretrainedConfig
 from ktransformers.models.modeling_llama import (
@@ -649,9 +649,14 @@ class KDeepseekV2Model(BaseInjectedModule):
         if per_layer_prefill_flag:
             causal_mask = None
         else:
-            causal_mask = self._update_causal_mask(
-                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
-            )
+            if os.name == 'nt' or get_compute_capability()<8:
+                print("for Windows or GPU before ampere, use forward_windows")
+                # only use mask in forward windows or can't flash attn
+                causal_mask = self._update_causal_mask(
+                    attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
+                )
+            else:
+                causal_mask = None

         # embed positions
         hidden_states = inputs_embeds
...
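The new branch builds an explicit causal mask only when flash attention cannot be used: on Windows, or on GPUs older than Ampere (compute capability below 8), which take the `forward_windows` path. `get_compute_capability` is imported from `ktransformers.util.utils` but its body is not shown in this diff; a hedged stand-in built on `torch.cuda.get_device_capability` illustrates the gating:

```python
import os
import torch

# Stand-in for ktransformers.util.utils.get_compute_capability (not shown in
# this diff). torch.cuda.get_device_capability returns a (major, minor)
# tuple, e.g. (8, 0) for A100 or (7, 5) for T4.
def get_compute_capability_sketch() -> int:
    if not torch.cuda.is_available():
        return 0
    major, _minor = torch.cuda.get_device_capability()
    return major

# Flash-attention kernels apply causality internally, so an explicit mask is
# only needed on Windows or on pre-Ampere GPUs.
needs_explicit_mask = os.name == "nt" or get_compute_capability_sketch() < 8
```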