# Copyright 2025 LMSYS and the LlamaFactory team.
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# This code is inspired by the LMSYS's FastChat library.
# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING

from ...extras import logging
from ...extras.constants import RopeScaling


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments") -> None:
    r"""Configure RoPE scaling on the model config according to user arguments.

    Mutates ``config`` in place: enlarges ``max_position_embeddings`` and installs a
    ``rope_scaling`` dict in the format expected by transformers
    (``{"rope_type": ..., "factor": ..., ...}``). No-op when the user did not request
    RoPE scaling or when the model does not expose a ``rope_scaling`` attribute.

    Args:
        config: The model config to modify (must expose ``rope_scaling`` and,
            normally, ``max_position_embeddings``).
        model_args: User arguments providing ``rope_scaling`` (a ``RopeScaling``
            enum member or its string value) and ``model_max_length``
            (``None`` at inference time).
    """
    if model_args.rope_scaling is None:
        return

    if not hasattr(config, "rope_scaling"):
        logger.warning_rank0("Current model does not support RoPE scaling.")
        return

    # Prefer the pre-scaling length if an earlier scaling config recorded it,
    # so re-applying scaling does not compound on an already-enlarged value.
    rope_scaling = getattr(config, "rope_scaling", None)
    if isinstance(rope_scaling, dict) and "original_max_position_embeddings" in rope_scaling:
        old_max_length = rope_scaling["original_max_position_embeddings"]
    elif hasattr(config, "max_position_embeddings"):
        old_max_length = getattr(config, "max_position_embeddings")
    else:
        logger.warning_rank0("Cannot find the max position embeddings in the config.")
        return

    if model_args.model_max_length is not None:  # training
        if model_args.model_max_length <= old_max_length:
            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
            return

        if model_args.rope_scaling == RopeScaling.DYNAMIC:
            logger.warning_rank0(
                "Dynamic NTK scaling may not work well with fine-tuning. "
                "See: https://github.com/huggingface/transformers/pull/24653"
            )

        # Round the ratio up so the scaled context always covers model_max_length.
        rope_factor = float(math.ceil(model_args.model_max_length / old_max_length))
    else:  # inference: double the context by default
        rope_factor = 2.0

    rope_kwargs = {
        "rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling),  # handle enum
        "factor": rope_factor,
    }
    # Fix: rope_factor is a float, so the product must be cast back to int —
    # max_position_embeddings is an integer field and downstream consumers
    # (embedding allocation, config serialization) expect an int.
    new_max_length = int(old_max_length * rope_factor)
    setattr(config, "max_position_embeddings", new_max_length)
    logger.info_rank0(f"Enlarge max model length from {old_max_length} to {new_max_length}.")

    # These rope types additionally need the original (pre-scaling) length.
    if model_args.rope_scaling in [RopeScaling.DYNAMIC, RopeScaling.YARN]:
        rope_kwargs["original_max_position_embeddings"] = old_max_length
    elif model_args.rope_scaling == RopeScaling.LLAMA3:
        rope_kwargs["original_max_position_embeddings"] = old_max_length
        # Default llama3 frequency-band boundaries used by transformers.
        rope_kwargs["low_freq_factor"] = 1.0
        rope_kwargs["high_freq_factor"] = 4.0

    setattr(config, "rope_scaling", rope_kwargs)
    logger.info_rank0(
        f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
    )