Unverified commit 8025def1, authored by Atream and committed by GitHub
Browse files

Merge pull request #1246 from aubreyli/GenerationMixin

modeling_deepseek_v3: fix GenerationMixin warning
parents 900a7f7c def1ec76
@@ -30,6 +30,7 @@ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from transformers.activations import ACT2FN
 from transformers.cache_utils import Cache, DynamicCache, StaticCache
+from transformers.generation import GenerationMixin
 from transformers.modeling_attn_mask_utils import (
     AttentionMaskConverter,
     _prepare_4d_attention_mask,
@@ -1598,7 +1599,7 @@ class DeepseekV3Model(DeepseekV3PreTrainedModel):
         return causal_mask
-class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel):
+class DeepseekV3ForCausalLM(DeepseekV3PreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
     def __init__(self, config):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment