utils.py 1.98 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
# Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Utility functions for Transformer Engine modules"""

from typing import Union

import paddle
import paddle.nn.functional as F


def cast_if_needed(tensor: Union[paddle.Tensor, None],
                   dtype: paddle.dtype) -> Union[paddle.Tensor, None]:
    """Return `tensor` converted to `dtype`, skipping the cast when it is unnecessary.

    `None` is passed through unchanged, and a tensor whose dtype already
    equals `dtype` is returned as-is (the same object). Otherwise the result
    of `paddle.cast(tensor, dtype)` is returned.
    """
    if tensor is None:
        return None
    if tensor.dtype == dtype:
        # Already the requested dtype: hand back the very same object.
        return tensor
    return paddle.cast(tensor, dtype)


18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
def cast_if_needed_inplace(tensor: Union[paddle.Tensor, None],
                           dtype: paddle.dtype) -> Union[paddle.Tensor, None]:
    """Cast `tensor` to `dtype` via its `_to` method (intended as an inplace cast).

    Not to be used on layer inputs. `None` is passed through unchanged, and a
    tensor whose dtype already equals `dtype` is returned as-is. Otherwise
    the value returned by `tensor._to(dtype=dtype)` is returned.
    """
    if tensor is None:
        return None
    if tensor.dtype == dtype:
        # Nothing to convert.
        return tensor
    return tensor._to(dtype=dtype)


def check_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> bool:
    """Tell whether `tensor` has dimensions usable for fp8 fprop (TN layout).

    Inputs and weights qualify when dim0 is divisible by 8 and dim1 is
    divisible by 16. Returns True when both conditions hold, False otherwise.
    """
    dims = tensor.shape
    if dims[0] % 8 != 0:
        # dim1 is deliberately not inspected once dim0 has already failed.
        return False
    return dims[1] % 16 == 0


def assert_dim_for_fp8_forward_exec(tensor: paddle.Tensor) -> None:
    """Assert that `tensor` passes `check_dim_for_fp8_forward_exec`.

    For fp8 fprop (TN layout), inputs and weights need dim0 divisible by 8
    and dim1 divisible by 16. An AssertionError reporting both dimensions is
    raised when the check fails.
    """
    # Tensors are validated one at a time so that a failure points at the
    # specific tensor that triggered it.
    assert check_dim_for_fp8_forward_exec(tensor), (
        "Tensor dimensions are not compatible for FP8 execution: "
        f"({tensor.shape[0]} % 8 != 0, {tensor.shape[1]} % 16 != 0)"
    )


def get_bias_dtype(activation_dtype: paddle.dtype):
    """Pick the bias dtype that goes with `activation_dtype`.

    float32 activations map to a bfloat16 bias; any other activation dtype
    is used for the bias unchanged.
    """
    if activation_dtype == paddle.float32:
        return paddle.bfloat16
    return activation_dtype


46
47
48
49
50
51
52
53
54
def get_paddle_act_func(activation):
    """Get paddle activation function.

    Args:
        activation: Name of the activation; one of 'gelu' or 'relu'.

    Returns:
        The matching callable from `paddle.nn.functional`.

    Raises:
        ValueError: If `activation` is not a supported name.
    """
    funcs = {
        'gelu': F.gelu,
        'relu': F.relu,
    }
    if activation not in funcs:
        # Raise a real exception object: raising a plain string is itself a
        # TypeError and would hide this message from the caller. The f-string
        # also avoids a concatenation failure when `activation` is not a str.
        raise ValueError(f"Activation type {activation} is not supported.")
    return funcs[activation]