# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
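
"""Helpers for selecting and reporting a model's attention implementation (eager, SDPA, FlashAttention-2)."""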

from typing import TYPE_CHECKING

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

from ...extras import logging
from ...extras.constants import AttentionFunction


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


def configure_attn_implementation(
    config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
) -> None:
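    """Resolve the requested attention backend and write it into the model config.

    Maps `model_args.flash_attn` (auto / disabled / sdpa / fa2) to an attention
    implementation, warning and keeping the config default when the requested
    backend is unavailable, then stores the result under the attribute name each
    model type expects (internlm2 and kimi_vl use non-standard locations).
    """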
    if getattr(config, "model_type", None) == "gemma2" and is_trainable:
        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
            if is_flash_attn_2_available():
                if model_args.flash_attn != AttentionFunction.FA2:
                    logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
                    model_args.flash_attn = AttentionFunction.FA2
            else:
                logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
                model_args.flash_attn = AttentionFunction.DISABLED
        elif model_args.flash_attn == AttentionFunction.SDPA:
            logger.warning_rank0(
                "Gemma-2 should use soft-capping attention, which SDPA attention does not support."
            )

    if model_args.flash_attn == AttentionFunction.AUTO:
        return

    elif model_args.flash_attn == AttentionFunction.DISABLED:
        requested_attn_implementation = "eager"

    elif model_args.flash_attn == AttentionFunction.SDPA:
        if not is_torch_sdpa_available():
            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"
    elif model_args.flash_attn == AttentionFunction.FA2:
        if not is_flash_attn_2_available():
            logger.warning_rank0("FlashAttention-2 is not installed.")
            return

        requested_attn_implementation = "flash_attention_2"
    else:
        raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")

    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        setattr(config, "attn_implementation", requested_attn_implementation)
    elif getattr(config, "model_type", None) == "kimi_vl":
        setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
        setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
    else:
        setattr(config, "_attn_implementation", requested_attn_implementation)


def print_attn_implementation(config: "PretrainedConfig") -> None:
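    """Log which attention implementation the model config resolved to."""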
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        attn_implementation = getattr(config, "attn_implementation", None)
    else:
        attn_implementation = getattr(config, "_attn_implementation", None)

    if attn_implementation == "flash_attention_2":
        logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
    elif attn_implementation == "sdpa":
        logger.info_rank0("Using torch SDPA for faster training and inference.")
    else:
        logger.info_rank0("Using vanilla attention implementation.")