rope.py
# Copyright 2025 LMSYS and the LlamaFactory team.
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# This code is inspired by LMSYS's FastChat library.
# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import TYPE_CHECKING

from ...extras import logging
from ...extras.constants import RopeScaling


if TYPE_CHECKING:
    from transformers import PretrainedConfig

    from ...hparams import ModelArguments


logger = logging.get_logger(__name__)


def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
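    """Configure RoPE scaling on the model config according to `model_args`.

    Mutates `config` in place: enlarges `max_position_embeddings` when a longer
    `model_max_length` is requested, and writes the `rope_scaling` kwargs dict
    that transformers consumes when building rotary position embeddings.
    """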
    if model_args.rope_scaling is None:
        return

    if not hasattr(config, "rope_scaling"):
        logger.warning_rank0("Current model does not support RoPE scaling.")
        return

    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
    if model_args.model_max_length is not None:
        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
            logger.warning_rank0(
                "Dynamic NTK scaling may not work well with fine-tuning. "
                "See: https://github.com/huggingface/transformers/pull/24653"
            )

        current_max_length = getattr(config, "max_position_embeddings", None)
        if (not current_max_length) or model_args.model_max_length <= current_max_length:
            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
            return

        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
        setattr(config, "max_position_embeddings", model_args.model_max_length)
        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
        if model_args.rope_scaling == RopeScaling.DYNAMIC:
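            # Dynamic NTK scaling references the pre-extension context length.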
            rope_kwargs["original_max_position_embeddings"] = current_max_length
        elif model_args.rope_scaling == RopeScaling.LLAMA3:
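            # Llama 3 reference values for the frequency-band interpolation factors.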
            rope_kwargs["original_max_position_embeddings"] = current_max_length
            rope_kwargs["low_freq_factor"] = 1.0
            rope_kwargs["high_freq_factor"] = 4.0
    else:
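        # No target length given; fall back to a default 2x scaling factor.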
        rope_kwargs["factor"] = 2.0

    setattr(config, "rope_scaling", rope_kwargs)
    logger.info_rank0(
        f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
    )
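

# A minimal usage sketch (illustrative only, not part of this module). The
# `SimpleNamespace` objects below stand in for the real `PretrainedConfig` and
# `ModelArguments`, and we assume `RopeScaling.DYNAMIC.value == "dynamic"`:
#
#     from types import SimpleNamespace
#
#     config = SimpleNamespace(rope_scaling=None, max_position_embeddings=4096)
#     model_args = SimpleNamespace(rope_scaling=RopeScaling.DYNAMIC, model_max_length=16384)
#     configure_rope(config, model_args, is_trainable=False)
#     # config.max_position_embeddings -> 16384
#     # config.rope_scaling -> {"rope_type": "dynamic", "factor": 4.0,
#     #                         "original_max_position_embeddings": 4096}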