# Copyright 2025 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

from ...extras import logging
from ...extras.constants import AttentionFunction


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


def configure_attn_implementation(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
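    """Configure the attention implementation requested in `model_args` on the model config."""
    # Gemma 2 relies on attention soft-capping, which SDPA lacks, so prefer FlashAttention-2 (or eager as a fallback).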
    if getattr(config, "model_type", None) == "gemma2":
        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
            if is_flash_attn_2_available():
                if model_args.flash_attn != AttentionFunction.FA2:
                    logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
                    model_args.flash_attn = AttentionFunction.FA2
            else:
                logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
                model_args.flash_attn = AttentionFunction.DISABLED
        elif model_args.flash_attn == AttentionFunction.SDPA:
            logger.warning_rank0(
                "Gemma 2 requires soft-capping attention, which the SDPA implementation does not support."
            )

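    # Map the requested AttentionFunction to the implementation string expected by transformers.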
    if model_args.flash_attn == AttentionFunction.AUTO:
        return

    elif model_args.flash_attn == AttentionFunction.DISABLED:
        requested_attn_implementation = "eager"

    elif model_args.flash_attn == AttentionFunction.SDPA:
        if not is_torch_sdpa_available():
            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
            return

        requested_attn_implementation = "sdpa"
    elif model_args.flash_attn == AttentionFunction.FA2:
        if not is_flash_attn_2_available():
            logger.warning_rank0("FlashAttention-2 is not installed.")
            return

        requested_attn_implementation = "flash_attention_2"
    else:
        raise NotImplementedError(f"Unknown attention type: {model_args.flash_attn}")

    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        setattr(config, "attn_implementation", requested_attn_implementation)
    elif getattr(config, "model_type", None) == "kimi_vl":
        setattr(config.vision_config, "_attn_implementation", requested_attn_implementation)
        setattr(config.text_config, "_attn_implementation", requested_attn_implementation)
    else:
        setattr(config, "_attn_implementation", requested_attn_implementation)


def print_attn_implementation(config: "PretrainedConfig") -> None:
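    """Log the attention implementation that is set on the model config."""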
    if getattr(config, "model_type", None) == "internlm2":  # special case for custom models
        attn_implementation = getattr(config, "attn_implementation", None)
    else:
        attn_implementation = getattr(config, "_attn_implementation", None)

    if attn_implementation == "flash_attention_2":
        logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
    elif attn_implementation == "sdpa":
        logger.info_rank0("Using torch SDPA for faster training and inference.")
    else:
        logger.info_rank0("Using vanilla attention implementation.")