#!/usr/bin/env python
# coding=utf-8
"""
Description  : KTransformers injected model modules (KQwen2MoeModel, KDeepseekV2Model).
Author       : Azure-Tang
Date         : 2024-07-25 11:25:24
Version      : 1.0.0
LastEditors  : Azure
LastEditTime : 2024-08-27 07:29:04
Copyright (c) 2024 by KVCache.AI, All Rights Reserved.
"""
chenxl's avatar
chenxl committed
12
13
14
15
16
17
18
19
20
21

import inspect
import math
from typing import List, Optional, Tuple, Union
import time
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
chenxl's avatar
chenxl committed
22
23
24
25
from ktransformers.operators.dynamic_attention import DynamicScaledDotProductAttention
from ktransformers.server.config.config import Config
import os
import yaml
chenxl's avatar
chenxl committed
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
from transformers.activations import ACT2FN
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import (
    AttentionMaskConverter,
)
from transformers.modeling_outputs import (
    MoeCausalLMOutputWithPast,
    MoeModelOutputWithPast,
    SequenceClassifierOutputWithPast,
    TokenClassifierOutput,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import (
    add_start_docstrings,
    add_start_docstrings_to_model_forward,
    is_flash_attn_2_available,
    is_flash_attn_greater_or_equal_2_10,
    logging,
    replace_return_docstrings,
)
chenxl's avatar
chenxl committed
46
47
48
49
50
51
52
53
54
55
from ktransformers.models.modeling_qwen2_moe import (
    Qwen2MoeSparseMoeBlock,
    Qwen2MoeMLP,
    Qwen2MoeDecoderLayer,
)
from ktransformers.models.modeling_deepseek import (
    BaseModelOutputWithPast,
    DeepseekV2DecoderLayer,
    DeepseekV2MoE,
)
chenxl's avatar
chenxl committed
56
from transformers.models.qwen2_moe.configuration_qwen2_moe import Qwen2MoeConfig
chenxl's avatar
chenxl committed
57
from ktransformers.models.configuration_llama import LlamaConfig
chenxl's avatar
chenxl committed
58
59
from ktransformers.operators.base_operator import BaseInjectedModule
from ktransformers.util.utils import InferenceState
chenxl's avatar
chenxl committed
60
61
from ktransformers.util.custom_gguf import GGUFLoader
from transformers.configuration_utils import PretrainedConfig
chenxl's avatar
chenxl committed
62
63
64
65
66
from ktransformers.models.modeling_llama import (
    LlamaDecoderLayer,
    LlamaRMSNorm,
    LlamaRotaryEmbedding,
)
chenxl's avatar
chenxl committed
67
68
69
70
71

if is_flash_attn_2_available():
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input  # noqa

    # Detect once, at import time, whether the installed ``flash_attn_func``
    # accepts a ``window_size`` keyword argument.
    _flash_supports_window_size = "window_size" in list(
        inspect.signature(flash_attn_func).parameters
    )
else:
    # Keep the flag defined when flash-attn is not available, so code that
    # consults it gets ``False`` instead of a NameError.
    _flash_supports_window_size = False


# Module-level logger from transformers' logging utilities (used for the
# ``warning_once`` calls in the model forward passes below).
logger = logging.get_logger(__name__)

# Identifiers of the kind consumed by HF docstring helpers; NOTE(review): not
# referenced in the visible part of this file — confirm they are still used.
_CHECKPOINT_FOR_DOC = "Qwen/Qwen1.5-MoE-A2.7B"
_CONFIG_FOR_DOC = "Qwen2MoeConfig"

# Class-level documentation prepended to ``KQwen2MoeModel`` through the
# ``add_start_docstrings`` decorator.
QWEN2MOE_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`Qwen2MoeConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Argument documentation injected into ``KQwen2MoeModel.forward`` through
# ``add_start_docstrings_to_model_forward``.
QWEN2MOE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        output_router_logits (`bool`, *optional*):
            Whether or not to return the logits of all the routers. They are useful for computing the router loss, and
            should not be returned during inference.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
        per_layer_prefill_intput_threshold (`int`, *optional*):
            Sequence length above which per-layer prefill is used: every decoder layer is unloaded first, then each
            layer is loaded for prefill only while it runs, and all layers are switched to generate mode afterwards.
            If `None`, the threshold given at construction time is used; `0` disables per-layer prefill for this call.
"""


@add_start_docstrings(
    "The bare Qwen2MoE Model outputting raw hidden-states without any specific head on top.",
    QWEN2MOE_START_DOCSTRING,
)
class KQwen2MoeModel(BaseInjectedModule):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`Qwen2MoeDecoderLayer`]

    Injected replacement for the Qwen2-MoE decoder stack. On top of the stock forward pass it supports:

    - per-layer prefill: when the input is longer than ``per_layer_prefill_intput_threshold`` tokens, all
      decoder layers are unloaded and each one is loaded only while it runs;
    - multi-device execution: at every layer index found in ``transfer_map`` the activations are moved to the
      mapped CUDA device and work continues on a per-device CUDA stream.

    Args:
        config: Qwen2MoeConfig
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill
        transfer_map: Optional[dict] = None,
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, device, **kwargs
        )
        # Default sequence-length threshold for per-layer prefill; ``forward`` may override it per call.
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
        # Decoder-layer index -> CUDA device that layer (and the following ones, until the next entry) runs on.
        self.transfer_map = transfer_map
        # Lazily created CUDA stream for each target device of ``transfer_map``.
        self.stream_device_map = dict()

    @add_start_docstrings_to_model_forward(QWEN2MOE_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_router_logits: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        # None -> use the constructor default; 0 -> disable per-layer prefill for this call.
        per_layer_prefill_intput_threshold: Optional[int] = None,
    ) -> Union[Tuple, MoeModelOutputWithPast]:
        # Validate the input pair before anything reads its shape or moves weights around:
        # checking later turned a missing input into an AttributeError and, when both inputs
        # were given, raised only after every layer had already been unloaded.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if per_layer_prefill_intput_threshold is None:
            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
        seq_length = (
            inputs_embeds.size(1) if inputs_embeds is not None else input_ids.size(1)
        )
        # A falsy threshold (None / 0) disables per-layer prefill.
        per_layer_prefill_flag = bool(
            per_layer_prefill_intput_threshold
            and per_layer_prefill_intput_threshold < seq_length
        )

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_router_logits = (
            output_router_logits
            if output_router_logits is not None
            else self.config.output_router_logits
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
            )
            use_cache = False

        if per_layer_prefill_flag:
            # Long input: unload every layer up front; the decoder loop loads each layer just
            # before it runs and unloads it right after.
            for layer in self.layers:
                self.load_layer_to(layer, InferenceState.UNLOAD)

        try:
            use_legacy_cache = False
            if use_cache and not isinstance(past_key_values, Cache):
                use_legacy_cache = True
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                logger.warning_once(
                    "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                    "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
                )

            if inputs_embeds is None:
                # Token ids are moved to CPU for the lookup (the embedding table is presumably
                # CPU-resident — TODO confirm); the embeddings are then moved to CUDA.
                input_ids = input_ids.to("cpu")
                inputs_embeds = self.embed_tokens(input_ids)
                inputs_embeds = inputs_embeds.to("cuda")

            if cache_position is None:
                past_seen_tokens = (
                    past_key_values.get_seq_length() if past_key_values is not None else 0
                )
                cache_position = torch.arange(
                    past_seen_tokens,
                    past_seen_tokens + inputs_embeds.shape[1],
                    device=inputs_embeds.device,
                )
            if position_ids is None:
                position_ids = cache_position.unsqueeze(0)

            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values,
                output_attentions,
            )

            hidden_states = inputs_embeds

            # decoder layers
            all_hidden_states = () if output_hidden_states else None
            all_self_attns = () if output_attentions else None
            all_router_logits = () if output_router_logits else None
            next_decoder_cache = None

            for i, decoder_layer in enumerate(self.layers):
                if self.transfer_map is not None and i in self.transfer_map:
                    # Hand off to another device: its stream waits for the work already queued
                    # on the current stream, becomes the current stream, and the per-layer
                    # inputs are copied over asynchronously.
                    prev_stream = torch.cuda.current_stream()
                    cur_device = self.transfer_map[i]
                    if cur_device not in self.stream_device_map:
                        self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
                    torch.cuda.set_device(cur_device)
                    self.stream_device_map[cur_device].wait_stream(prev_stream)
                    torch.cuda.set_stream(self.stream_device_map[cur_device])
                    hidden_states = hidden_states.to(cur_device, non_blocking=True)
                    causal_mask = (
                        causal_mask.to(cur_device, non_blocking=True)
                        if causal_mask is not None
                        else None
                    )
                    position_ids = (
                        position_ids.to(cur_device, non_blocking=True)
                        if position_ids is not None
                        else None
                    )
                    cache_position = (
                        cache_position.to(cur_device, non_blocking=True)
                        if cache_position is not None
                        else None
                    )

                if output_hidden_states:
                    all_hidden_states += (hidden_states,)

                if self.gradient_checkpointing and self.training:
                    layer_outputs = self._gradient_checkpointing_func(
                        decoder_layer.__call__,
                        hidden_states,
                        causal_mask,
                        position_ids,
                        past_key_values,
                        output_attentions,
                        output_router_logits,
                        use_cache,
                        cache_position,
                    )
                else:
                    if per_layer_prefill_flag:
                        self.load_layer_to(decoder_layer, InferenceState.PREFILL)
                        torch.cuda.empty_cache()
                    layer_outputs = decoder_layer(
                        hidden_states,
                        attention_mask=causal_mask,
                        position_ids=position_ids,
                        past_key_value=past_key_values,
                        output_attentions=output_attentions,
                        output_router_logits=output_router_logits,
                        use_cache=use_cache,
                        cache_position=cache_position,
                    )
                    if per_layer_prefill_flag:
                        self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
                        torch.cuda.empty_cache()
                hidden_states = layer_outputs[0]

                # Layer output layout: (hidden, [attn_weights], [cache], ..., [router_logits]).
                if use_cache:
                    next_decoder_cache = layer_outputs[2 if output_attentions else 1]

                if output_attentions:
                    all_self_attns += (layer_outputs[1],)

                if output_router_logits and layer_outputs[-1] is not None:
                    all_router_logits += (layer_outputs[-1],)

            hidden_states = self.norm(hidden_states)
        finally:
            # Always leave the layers in generate mode after per-layer prefill, including when
            # a layer raised (e.g. out of memory) — otherwise they would stay unloaded.
            if per_layer_prefill_flag:
                for layer in self.layers:
                    self.load_layer_to(layer, InferenceState.GENERATE)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    next_cache,
                    all_hidden_states,
                    all_self_attns,
                    all_router_logits,
                ]
                if v is not None
            )
        return MoeModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
            router_logits=all_router_logits,
        )

    def load_layer_to(self, layer: Qwen2MoeDecoderLayer, target: InferenceState):
        """Switch every sub-module of one decoder layer to the ``target`` inference state.

        Injected linears, the MoE gate and the experts are switched via ``set_inference_mode``;
        plain modules (rotary embedding, activations, shared-expert gate, layer norms) are
        moved with ``.to(device)``, where ``device`` is ``"cpu"`` for ``InferenceState.UNLOAD``
        and ``"cuda"`` for any other state.

        Args:
            layer: the decoder layer to switch; must be a ``Qwen2MoeDecoderLayer``.
            target: the inference state to switch to.
        """
        assert isinstance(
            layer, Qwen2MoeDecoderLayer
        ), "module should be nn.ModuleList of decoder layers"

        # TODO Support restore to original device, not only cuda
        device = "cpu" if target == InferenceState.UNLOAD else "cuda"

        # attn
        layer.self_attn.q_proj.set_inference_mode(target)
        layer.self_attn.k_proj.set_inference_mode(target)
        layer.self_attn.v_proj.set_inference_mode(target)
        layer.self_attn.o_proj.set_inference_mode(target)
        layer.self_attn.rotary_emb = layer.self_attn.rotary_emb.to(device)

        # mlp: sparse MoE block (gate + experts + shared expert) or a dense MLP
        if isinstance(layer.mlp, Qwen2MoeSparseMoeBlock):
            layer.mlp.gate.set_inference_mode(target)
            layer.mlp.experts.set_inference_mode(target)
            layer.mlp.shared_expert.gate_proj.set_inference_mode(target)
            layer.mlp.shared_expert.up_proj.set_inference_mode(target)
            layer.mlp.shared_expert.down_proj.set_inference_mode(target)
            layer.mlp.shared_expert.act_fn.to(device)
            layer.mlp.shared_expert_gate.to(device)
        else:
            layer.mlp.gate_proj.set_inference_mode(target)
            layer.mlp.up_proj.set_inference_mode(target)
            layer.mlp.down_proj.set_inference_mode(target)
            layer.mlp.act_fn.to(device)
        # layer norm
        layer.input_layernorm.to(device)
        layer.post_attention_layernorm.to(device)


# Argument documentation for the DeepSeek-V2 model forward pass (consumed through
# ``add_start_docstrings_to_model_forward``).
DeepseekV2_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


class KDeepseekV2Model(BaseInjectedModule):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeepseekV2DecoderLayer`]

    This injected module adds two execution features on top of the plain
    decoder loop:

    * Per-layer prefill: when the input sequence is longer than
      ``per_layer_prefill_intput_threshold``, every layer is first switched to
      ``InferenceState.UNLOAD``; each layer is switched to
      ``InferenceState.PREFILL`` only while it runs and unloaded again right
      after. When the loop finishes, all layers are switched to
      ``InferenceState.GENERATE``.
    * Device hand-off: ``transfer_map`` maps a layer index to a device; right
      before that layer runs, the activations are moved to that device on a
      per-device CUDA stream.

    Args:
        config: DeepseekV2Config
    """

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill
        transfer_map: Optional[dict] = None,
        **kwargs,
    ):
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, device, **kwargs
        )
        # Sequence length above which forward() uses per-layer prefill; a
        # falsy value (None / 0) disables it.
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
        # {layer_index: device}; activations are moved to `device` right before
        # that layer runs. None disables device hand-off.
        self.transfer_map = transfer_map
        # device -> torch.cuda.Stream, created lazily in forward().
        self.stream_device_map = dict()

    @add_start_docstrings_to_model_forward(DeepseekV2_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
        per_layer_prefill_intput_threshold: (
            int | None
        ) = None,  # if None, fall back to self.per_layer_prefill_intput_threshold
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        # Validate the inputs before touching any model state, so a bad call
        # raises ValueError without having unloaded the layers first.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape[:2]
        elif inputs_embeds is not None:
            batch_size, seq_length = inputs_embeds.shape[:2]
        else:
            raise ValueError("You have to specify either input_ids or inputs_embeds")

        if per_layer_prefill_intput_threshold is None:
            per_layer_prefill_intput_threshold = self.per_layer_prefill_intput_threshold
        # A falsy threshold (None / 0) disables per-layer prefill.
        per_layer_prefill_flag = bool(
            per_layer_prefill_intput_threshold
            and per_layer_prefill_intput_threshold < seq_length
        )
        if per_layer_prefill_flag:
            # Unload every layer up front; each one is loaded just before it runs.
            for layer in self.layers:
                self.load_layer_to(layer, InferenceState.UNLOAD)
            torch.cuda.empty_cache()

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`transformers."
                )
                use_cache = False

        # Legacy tuple caches are wrapped in a DynamicCache here and converted
        # back before returning, so callers get the format they passed in.
        use_legacy_cache = False
        if use_cache:
            use_legacy_cache = not isinstance(past_key_values, Cache)
            if use_legacy_cache:
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            # `inputs_embeds` is only looked up further below and may still be
            # None here, so take the device from whichever input was supplied.
            ref_input = inputs_embeds if inputs_embeds is not None else input_ids
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + seq_length,
                device=ref_input.device,
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        if inputs_embeds is None:
            # The embedding lookup is done with CPU input ids; presumably the
            # embedding table lives on CPU (TODO confirm against injection rules).
            inputs_embeds = self.embed_tokens(input_ids.to("cpu"))

        # In per-layer prefill mode no explicit causal mask is built;
        # every layer receives attention_mask=None.
        if per_layer_prefill_flag:
            causal_mask = None
        else:
            causal_mask = self._update_causal_mask(
                attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
            )

        # embed positions
        hidden_states = inputs_embeds
        if per_layer_prefill_flag:
            print(f"Total length of input_ids: {hidden_states.size(1)}")

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        # Wall-clock accounting for the per-layer prefill report.
        t_gpu = 0  # time spent loading layers (PREFILL)
        t_cpu = 0  # time spent unloading layers again (UNLOAD)
        t_f = 0  # time spent in the decoder layers' forward
        t_start = time.time()

        for i, decoder_layer in enumerate(self.layers):
            if self.transfer_map is not None and i in self.transfer_map:
                # Switch to a dedicated stream on the target device; it first
                # waits for the work already queued on the previous stream.
                prev_stream = torch.cuda.current_stream()
                cur_device = self.transfer_map[i]
                if cur_device not in self.stream_device_map:
                    self.stream_device_map[cur_device] = torch.cuda.Stream(cur_device)
                torch.cuda.set_device(cur_device)
                self.stream_device_map[cur_device].wait_stream(prev_stream)
                torch.cuda.set_stream(self.stream_device_map[cur_device])
                hidden_states = hidden_states.to(
                    self.transfer_map[i], non_blocking=True
                )
                causal_mask = (
                    causal_mask.to(self.transfer_map[i], non_blocking=True)
                    if causal_mask is not None
                    else None
                )
                position_ids = (
                    position_ids.to(self.transfer_map[i], non_blocking=True)
                    if position_ids is not None
                    else None
                )
                cache_position = (
                    cache_position.to(self.transfer_map[i], non_blocking=True)
                    if cache_position is not None
                    else None
                )

            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                t_load_start = time.time()
                if per_layer_prefill_flag:
                    self.load_layer_to(decoder_layer, InferenceState.PREFILL)
                    torch.cuda.empty_cache()
                t_forward_start = time.time()
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )
                t_forward_end = time.time()
                if per_layer_prefill_flag:
                    self.load_layer_to(decoder_layer, InferenceState.UNLOAD)
                    torch.cuda.empty_cache()
                t_unload_end = time.time()
                # Accumulate inside this branch: the timestamps only exist
                # here (accumulating after the if/else raised NameError when
                # gradient checkpointing was active).
                t_gpu += t_forward_start - t_load_start
                t_cpu += t_unload_end - t_forward_end
                t_f += t_forward_end - t_forward_start

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if per_layer_prefill_flag:
            # Restore every layer to its decode-time placement.
            t_restore_start = time.time()
            for layer in self.layers:
                self.load_layer_to(layer, InferenceState.GENERATE)
            torch.cuda.empty_cache()
            t_restore_end = time.time()

            print(
                f"total time: {t_restore_end-t_start}, \n layer num{len(self.layers)}, gpu time: {t_gpu}, cpu time: {t_cpu}, forward time: {t_f}, restore time: {t_restore_end-t_restore_start}"
            )

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache()
                if use_legacy_cache
                else next_decoder_cache
            )
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def load_layer_to(self, layer: DeepseekV2DecoderLayer, target: InferenceState):
        """Place one decoder layer's sub-modules according to ``target``.

        Plain sub-modules (attention, MoE gate, activations, layer norms) are
        moved with ``.to(device)``, where ``device`` is ``"cpu"`` for
        ``InferenceState.UNLOAD`` and ``"cuda"`` for any other state.
        Sub-modules exposing ``set_inference_mode`` (the experts and the
        projection layers) are switched with ``set_inference_mode(target)``.
        """
        assert isinstance(
            layer, DeepseekV2DecoderLayer
        ), "layer should be a DeepseekV2DecoderLayer"

        # TODO Support restore to original device, not only cuda
        device = "cpu" if target == InferenceState.UNLOAD else "cuda"

        # TODO Support DFS to auto use {to, set_inference_mode} according to the module type

        # attn
        layer.self_attn.to(device)

        # mlp
        if isinstance(layer.mlp, DeepseekV2MoE):
            layer.mlp.gate.to(device)
            layer.mlp.experts.set_inference_mode(target)
            layer.mlp.shared_experts.gate_proj.set_inference_mode(target)
            layer.mlp.shared_experts.up_proj.set_inference_mode(target)
            layer.mlp.shared_experts.down_proj.set_inference_mode(target)
            layer.mlp.shared_experts.act_fn.to(device)
        else:
            layer.mlp.gate_proj.set_inference_mode(target)
            layer.mlp.up_proj.set_inference_mode(target)
            layer.mlp.down_proj.set_inference_mode(target)
            layer.mlp.act_fn.to(device)
        # layer norm
        layer.input_layernorm.to(device)
        layer.post_attention_layernorm.to(device)
chenxl's avatar
chenxl committed
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336


# Class-level docstring prefix for the LLaMA model classes; it is passed to
# `add_start_docstrings` on LlamaPreTrainedModel. Runtime text -- edits here
# change the generated class documentation.
LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""

# Shared argument documentation for LLaMA forward methods (Hugging Face
# docstring convention); presumably attached to a forward method later in the
# file via `add_start_docstrings_to_model_forward` -- TODO confirm.
# NOTE(review): the `attention_mask` entry ends with two lines about the
# "head" being masked, which look copy-pasted from a `head_mask` description;
# consider removing them (runtime text, so change deliberately).
LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
            Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used to speed up sequential decoding. This typically consists in the `past_key_values`
            returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.

            Two formats are allowed:
            - a [`~cache_utils.Cache`] instance;
            - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
            shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`). This is also known as the legacy
            cache format.

            The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
            legacy cache format will be returned.

            If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
            have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
            of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
        cache_position (`torch.LongTensor` of shape `(sequence_length)`, *optional*):
            Indices depicting the position of the input sequence tokens in the sequence. Contrarily to `position_ids`,
            this tensor is not affected by padding. It is used to update the cache in the correct position and to infer
            the complete sequence length.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    # Configuration / checkpoint wiring used by `PreTrainedModel`.
    config_class = LlamaConfig
    base_model_prefix = "model"
    # Capability flags consulted by the Hugging Face loading and generation code.
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Initialise Linear / Embedding weights from N(0, config.initializer_range).

        Linear biases are zeroed, and so is the embedding row at
        ``padding_idx`` when one is set. Other module types are left untouched.
        """
        init_std = self.config.initializer_range

        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=init_std)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=init_std)
            pad_row = module.padding_idx
            if pad_row is not None:
                module.weight.data[pad_row].zero_()


class KLlamaModel(BaseInjectedModule):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    This injected module runs long prompts through the decoder stack in slices of
    ``long_context["chunk_size"]`` tokens. It also drives a class-wide
    ``DynamicScaledDotProductAttention`` helper (``KLlamaModel.dynamic_sdpa``). That
    helper is configured from ``~/.ktransformers/<Config.CONFIG_FILE_NAME>``.

    Args:
        config: LlamaConfig
    """

    # Shared sparse-attention helper. Every __init__ call reassigns it, so the most
    # recently constructed KLlamaModel's settings are the ones in effect.
    dynamic_sdpa = None

    def __init__(
        self,
        key: str,
        gguf_loader: GGUFLoader,
        config: PretrainedConfig,
        orig_module: nn.Module,
        device: str = "cuda",
        per_layer_prefill_intput_threshold: int = 30000,  # if None, no per-layer prefill
        transfer_map: dict = None,
        **kwargs,
    ):
        """Wrap ``orig_module`` and build the shared dynamic sparse-attention helper.

        Args:
            key: Module key, forwarded to ``BaseInjectedModule``.
            gguf_loader: Weight loader, forwarded to ``BaseInjectedModule``.
            config: Model configuration. It is also passed to
                ``DynamicScaledDotProductAttention``.
            orig_module: The original module being replaced.
            device: Forwarded to ``BaseInjectedModule``.
                NOTE(review): ``dynamic_sdpa`` is always created on
                ``torch.device("cuda")``, regardless of this value.
            per_layer_prefill_intput_threshold: Stored on the instance.
            transfer_map: Stored on the instance.
            **kwargs: Forwarded to ``BaseInjectedModule``.

        Raises:
            OSError: If the ktransformers config file cannot be opened.
            KeyError, TypeError: If a required ``long_context`` / ``ext`` entry
                is missing from that file.
        """
        BaseInjectedModule.__init__(
            self, key, gguf_loader, config, orig_module, device, **kwargs
        )
        self.per_layer_prefill_intput_threshold = per_layer_prefill_intput_threshold
        self.transfer_map = transfer_map
        self.stream_device_map = dict()

        user_path: str = os.path.expanduser("~")
        localstore_path: str = os.path.join(user_path, ".ktransformers")
        config_path: str = os.path.join(localstore_path, Config.CONFIG_FILE_NAME)
        with open(config_path, "r", encoding="utf-8") as file:
            config_yaml = yaml.safe_load(file)
        self.long_context_config = config_yaml.get("long_context")
        self.ext_config = config_yaml.get("ext")

        long_ctx = self.long_context_config
        KLlamaModel.dynamic_sdpa = DynamicScaledDotProductAttention(
            max_seq_len=long_ctx["max_seq_len"],
            block_size=long_ctx["block_size"],
            config=config,
            device=torch.device("cuda"),
            local_windows_len=long_ctx["local_windows_len"],
            topk=long_ctx["second_select_num"],
            threads_num=self.ext_config["cpu_infer"],
            anchor_type=long_ctx["anchor_type"],
            kv_type=long_ctx["kv_type"],
            dense_layer_num=long_ctx["dense_layer_num"],
            anchor_num=long_ctx["anchor_num"],
            preselect_block=long_ctx["preselect_block"],
            block_selection_mode=long_ctx["head_select_mode"],
            preselect_block_count=long_ctx["preselect_block_count"],
            layer_step=long_ctx["layer_step"],
            token_step=long_ctx["token_step"],
            prefill_chunk_size=long_ctx["chunk_size"],
            use_attn_sparsity=False,
        )

    def get_input_embeddings(self):
        """Return the token embedding module."""
        return self.embed_tokens

    def set_input_embeddings(self, value):
        """Replace the token embedding module with ``value``."""
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack along one of three paths, chosen by query length.

        * ``q_len == 1`` (decode): only the last position is run, and the
          ``forward_chunk`` result is returned unchanged.
        * ``q_len <= chunk_size``: a single ``forward_chunk`` call runs on CUDA.
          Afterwards ``dynamic_sdpa.calc_anchor`` / ``clear_importance`` are
          called.
        * otherwise (chunked prefill): the prompt is fed through in
          ``chunk_size`` slices. Only the final slice's ``last_hidden_state``
          is returned, with no cache, hidden states, or attentions.

        ``attention_mask`` is not used on any of these paths.

        Raises:
            ValueError: If neither or both of ``input_ids`` / ``inputs_embeds``
                are given, or if the configured chunk size is not positive on
                the chunked path.
            AssertionError: If ``output_attentions`` is requested on the chunked
                path.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if (
            use_cache and not isinstance(past_key_values, Cache) and not self.training
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)
            logger.warning_once(
                "We detected that you are passing `past_key_values` as a tuple and this is deprecated and will be removed in v4.43. "
                "Please use an appropriate `Cache` class (https://huggingface.co/docs/transformers/v4.41.3/en/internal/generation_utils#transformers.Cache)"
            )

        # Embed first. The cache_position fallback below needs the sequence length
        # from inputs_embeds, which would still be None if only input_ids were given.
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids.to("cpu"))

        if cache_position is None:
            past_seen_tokens = (
                past_key_values.get_seq_length() if past_key_values is not None else 0
            )
            cache_position = torch.arange(
                past_seen_tokens,
                past_seen_tokens + inputs_embeds.shape[1],
                device="cuda",
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        # No mask is built here: _update_causal_mask is not called on these paths.
        causal_mask = None
        chunk_size = self.long_context_config["chunk_size"]
        q_len = cache_position.size(0)

        # generate
        if q_len == 1:
            x = inputs_embeds[:, -1:, :]
            position_ids = position_ids[:, -1:]
            return self.forward_chunk(
                x,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                output_hidden_states,
                return_dict,
            )
        elif q_len <= chunk_size:
            inputs_embeds = inputs_embeds.to("cuda")
            output = self.forward_chunk(
                inputs_embeds,
                causal_mask,
                position_ids,
                past_key_values,
                output_attentions,
                use_cache,
                cache_position,
                output_hidden_states,
                return_dict,
            )
            KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
            KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
            return output

        assert (
            not output_attentions
        ), "output_attentions is not supported when using chunked attention"
        if chunk_size <= 0:
            # A non-positive step would never advance cur_idx (infinite loop).
            raise ValueError(
                f"long_context chunk_size must be positive, got {chunk_size}"
            )

        attn_output = None
        # prefill
        KLlamaModel.dynamic_sdpa.remaining_length = q_len
        cur_idx = 0
        while cur_idx < q_len:
            print(f'current prefill length: {cur_idx}')
            chunk_end = min(cur_idx + chunk_size, q_len)
            tmp_inputs_embeds = inputs_embeds[:, cur_idx:chunk_end]
            if tmp_inputs_embeds.device.type == "cpu":
                tmp_inputs_embeds = tmp_inputs_embeds.to("cuda")
            output_with_past = self.forward_chunk(
                tmp_inputs_embeds,
                None,
                position_ids[:, cur_idx:chunk_end],
                past_key_values,
                output_attentions,
                use_cache,
                cache_position[cur_idx:chunk_end],
                # Force a ModelOutput so `.last_hidden_state` works even when
                # config.use_return_dict is False.
                return_dict=True,
            )
            KLlamaModel.dynamic_sdpa.remaining_length -= chunk_end - cur_idx
            # Only the last slice's hidden states are kept (earlier slices just
            # populate the KV cache); callers only see the final positions.
            attn_output = output_with_past.last_hidden_state
            cur_idx = chunk_end

        KLlamaModel.dynamic_sdpa.calc_anchor(cache_position[-1] + 1)
        KLlamaModel.dynamic_sdpa.clear_importance(cache_position[-1] + 1)
        return BaseModelOutputWithPast(last_hidden_state=attn_output)

    def forward_chunk(
        self,
        inputs_embeds,
        causal_mask,
        position_ids,
        past_key_values,
        output_attentions,
        use_cache,
        cache_position,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run one slice of embedded tokens through every decoder layer and the final norm.

        Args:
            inputs_embeds: Embedded tokens for this slice (used as the initial
                hidden states).
            causal_mask: Passed to each decoder layer as ``attention_mask``.
            position_ids: Positions for the rotary embedding and the layers.
            past_key_values: KV cache. A non-``Cache`` value is converted with
                ``DynamicCache.from_legacy_cache`` when ``use_cache`` is set, and
                converted back to legacy form on return.
            output_attentions: Collect per-layer attention outputs.
            use_cache: Return the cache produced by the last decoder layer.
            cache_position: Cache indices for this slice.
            output_hidden_states: Defaults to ``config.output_hidden_states``.
            return_dict: Defaults to ``config.use_return_dict``.

        Returns:
            ``BaseModelOutputWithPast``, or a tuple of its non-``None`` fields
            when ``return_dict`` is false.
        """
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )
        return_legacy_cache = False
        if use_cache and not isinstance(
            past_key_values, Cache
        ):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                # The cache follows the attention weights when those are returned.
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache and next_cache is not None:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    def _update_causal_mask(
        self,
        attention_mask: torch.Tensor,
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool,
    ):
        """Build the additive causal mask for ``input_tensor``, or return ``None``.

        The result depends on ``config._attn_implementation``:

        * ``flash_attention_2``: the 2D ``attention_mask`` is returned when it
          contains padding (a 0); otherwise ``None`` is returned.
        * ``sdpa`` (no static cache, no attention output): ``None`` is returned
          when ``AttentionMaskConverter`` says the causal mask can be ignored.
        * otherwise: a ``(batch, 1, seq, target)`` mask is built. It holds the
          dtype minimum for masked positions and 0 elsewhere, with padding from
          ``attention_mask`` merged in.

        A 4D ``attention_mask`` is taken as already inverted (max must be 0).

        Raises:
            ValueError: If a 4D ``attention_mask`` has a non-zero maximum.

        NOTE(review): not called by ``forward``/``forward_chunk`` in this class.
        """
        # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
        # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
        # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
        # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114

        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = (
            past_key_values.get_seq_length() if past_key_values is not None else 0
        )
        using_static_cache = isinstance(past_key_values, StaticCache)

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if (
            self.config._attn_implementation == "sdpa"
            and not using_static_cache
            and not output_attentions
        ):
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if using_static_cache:
            target_length = past_key_values.get_max_length()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        if attention_mask is not None and attention_mask.dim() == 4:
            # in this case we assume that the mask comes already in inverted form and requires no inversion or slicing
            if attention_mask.max() != 0:
                raise ValueError(
                    "Custom 4D attention mask should be passed in inverted form with max==0`"
                )
            causal_mask = attention_mask
        else:
            causal_mask = torch.full(
                (sequence_length, target_length),
                fill_value=min_dtype,
                dtype=dtype,
                device=device,
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep min_dtype only at key positions strictly after each query's cache position.
            causal_mask *= torch.arange(
                target_length, device=device
            ) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(
                input_tensor.shape[0], 1, -1, -1
            )
            if attention_mask is not None:
                causal_mask = (
                    causal_mask.clone()
                )  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                padding_mask = (
                    causal_mask[:, :, :, :mask_length]
                    + attention_mask[:, None, None, :]
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[
                    :, :, :, :mask_length
                ].masked_fill(padding_mask, min_dtype)
        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            causal_mask = AttentionMaskConverter._unmask_unattended(
                causal_mask, min_dtype
            )

        return causal_mask