# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""

from typing import Optional, Tuple, Union

import paddle
import paddle.nn.functional as F

from .cpp_extensions import swiglu_pd

def cast_if_needed(
    tensor: Union[paddle.Tensor, None], dtype: paddle.dtype
) -> Union[paddle.Tensor, None]:
    """Return *tensor* converted to *dtype*.

    ``None`` inputs and tensors already of *dtype* are passed through unchanged;
    otherwise a new tensor is produced via ``paddle.cast`` (out-of-place).
    """
    if tensor is None or tensor.dtype == dtype:
        return tensor
    return paddle.cast(tensor, dtype)


def cast_if_needed_inplace(
    tensor: Union[paddle.Tensor, None], dtype: paddle.dtype
) -> Union[paddle.Tensor, None]:
    """Cast tensor to dtype (inplace), not to be used on layer inputs"""
    if tensor is None:
        return tensor
    if tensor.dtype == dtype:
        # Already the requested dtype; nothing to convert.
        return tensor
    # NOTE: `_to` converts in place, so this must never run on layer inputs.
    return tensor._to(dtype=dtype)


def check_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> bool:
    """For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.

    Returns ``True`` when *tensor* satisfies both divisibility constraints.
    """
    dim0, dim1 = tensor.shape[0], tensor.shape[1]
    return dim0 % 8 == 0 and dim1 % 16 == 0


def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None:
    """For fp8 fprop (TN layout), inputs and weights must be such
    that dim0 is divisible by 8 and dim1 is divisible by 16.

    Raises:
        AssertionError: if either dimension violates its divisibility
            constraint. The message reports the actual dimensions so the
            offending one can be identified.
    """
    # single tensor check so it's clear which tensor is triggering the assertion
    assert check_dim_for_fp8_forward_exec(tensor), (
        "Tensor dimensions are not compatible for FP8 execution: "
        # Report both dims with their required divisors rather than claiming
        # both checks failed (the original message did, which was misleading
        # when only one dimension was bad).
        f"dim0={tensor.shape[0]} (must be divisible by 8), "
        f"dim1={tensor.shape[1]} (must be divisible by 16)"
    )


def get_bias_dtype(activation_dtype: paddle.dtype):
    """Get bias dtype given activation_dtype"""
    # float32 activations use a bfloat16 bias; every other dtype is kept as-is.
    if activation_dtype == paddle.float32:
        return paddle.bfloat16
    return activation_dtype


def get_paddle_act_func(activation):
    """Return the paddle activation function for *activation*.

    Args:
        activation: name of the activation; one of ``"gelu"``, ``"relu"``,
            ``"silu"``, ``"swiglu"``.

    Raises:
        ValueError: if *activation* is not a supported type.
    """
    funcs = {
        "gelu": F.gelu,
        "relu": F.relu,
        "silu": F.silu,
        "swiglu": swiglu_pd,
    }
    if activation not in funcs:
        # The original `raise "<str>"` is itself a TypeError in Python 3
        # (exceptions must derive from BaseException); raise a real exception.
        raise ValueError("Activation type " + activation + " is not supported.")
    return funcs[activation]
def attention_mask_func(
    attention_scores: paddle.Tensor, attention_mask: paddle.Tensor
) -> paddle.Tensor:
    """Apply *attention_mask* to *attention_scores*.

    Wherever the mask is True, the score is replaced with -10000.0 so that
    those positions are effectively zeroed out by a subsequent softmax.
    """
    # Inline masked-fill: pick the fill value where masked, the score otherwise.
    fill = paddle.full(attention_scores.shape, -10000.0, attention_scores.dtype)
    return paddle.where(attention_mask, fill, attention_scores)


def mask_to_cu_seqlens(mask: paddle.Tensor, need_kv: bool = False) -> paddle.Tensor:
    """Convert a boolean attention mask into cumulative sequence lengths.

    *mask* must be a bool tensor shaped [b, 1, s_q, s_kv] where True marks
    padded positions. Returns ``(q_cu_seqlens, kv_cu_seqlens)``;
    ``kv_cu_seqlens`` is ``None`` unless *need_kv* is set.
    """
    assert "bool" in str(mask.dtype), "mask must be bool dtype"
    assert len(mask.shape) == 4 and mask.shape[1] == 1, "mask must be [b, 1, s_q, s_kv]"

    def _cu_from_lens(actual_seqlens):
        # Prefix-sum the per-batch lengths and prepend a zero offset.
        cu = paddle.cumsum(actual_seqlens)
        return paddle.concat([paddle.zeros([1], dtype=paddle.int32), cu], axis=0)

    # Query lengths: count unmasked rows against the first kv column.
    q_lens = paddle.sum(mask[:, :, :, 0].logical_not(), axis=(-1, -2), dtype="int32")
    q_cu_seqlens = _cu_from_lens(q_lens)
    if not need_kv:
        return q_cu_seqlens, None
    # KV lengths: count unmasked columns against the first query row.
    kv_lens = paddle.sum(mask[:, :, 0, :].logical_not(), axis=(-1, -2), dtype="int32")
    return q_cu_seqlens, _cu_from_lens(kv_lens)


def divide(numerator: int, denominator: int) -> int:
    """Ensure that numerator is divisible by the denominator and return
    the division value."""
    quotient, remainder = divmod(numerator, denominator)
    assert remainder == 0, f"{numerator} is not divisible by {denominator}"
    return quotient


def save_for_backward_allow_none(ctx, *args) -> None:
    """Save tensors for backward. Args could be None

    Records an index map on ``ctx`` so `saved_tensor_allow_none` can restore
    ``None`` placeholders in their original positions.
    """
    mapping = []
    to_save = []
    for arg in args:
        if arg is None:
            mapping.append(-1)  # sentinel: restore as None
        elif isinstance(arg, paddle.Tensor):
            mapping.append(len(to_save))
            to_save.append(arg)
        else:
            raise ValueError(f"Type {type(arg)} is not allowed.")

    ctx._indices_mapping = mapping
    ctx.save_for_backward(*to_save)


def saved_tensor_allow_none(ctx) -> Tuple[Optional[paddle.Tensor]]:
    """Used with `save_for_backward_allow_none` in pair. Get saved tensors from ctx.

    Rebuilds the original argument tuple, re-inserting ``None`` wherever the
    index map recorded a -1 sentinel.
    """
    assert hasattr(
        ctx, "_indices_mapping"
    ), "`saved_tensor_allow_none` must be used with `save_for_backward_allow_none` in pair."

    saved = ctx.saved_tensor()
    return tuple(None if idx < 0 else saved[idx] for idx in ctx._indices_mapping)


def clear_tensor_data(*tensors: Tuple[Optional[paddle.Tensor], ...]) -> None:
    """
    Free tensor buffer

    Releases the underlying data pointer of each argument that is a real,
    initialized paddle tensor which has never been modified in place
    (``inplace_version == 0``). ``None`` and non-tensor entries are skipped.
    """
    for tensor in tensors:
        releasable = (
            tensor is not None
            and isinstance(tensor, paddle.Tensor)
            and tensor._is_initialized()
            and tensor.inplace_version == 0
        )
        if releasable:
            tensor._clear_dataptr()