# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utilities for models."""

import math

import torch

from megatron.training import get_args
from megatron.legacy.model import LayerNorm, RMSNorm
from megatron.core.jit import jit_fuser

def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)

    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_
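
# Illustrative sketch, not part of the upstream module: the two factories above
# are typically used together, e.g. with a hypothetical sigma of 0.02 and a
# 24-layer model. scaled_init_method_normal shrinks the std by sqrt(2*num_layers)
# so residual-branch output weights do not grow the activations with depth.
#
#     init_fn = init_method_normal(0.02)
#     scaled_fn = scaled_init_method_normal(0.02, num_layers=24)
#     w = torch.empty(4096, 4096)
#     init_fn(w)      # in-place fill, std = 0.02
#     scaled_fn(w)    # in-place fill, std = 0.02 / math.sqrt(48) ~= 0.0029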


def attention_mask_func(attention_scores, attention_mask):
    """Fill masked positions of the attention scores with a large negative value."""
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
    if get_args().perform_initialization:
        init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer
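
# Illustrative sketch, not part of the upstream module: get_linear_layer() returns
# a plain torch.nn.Linear whose weight is filled by the given init method (when
# args.perform_initialization is set) and whose bias is zeroed, e.g. for a
# hypothetical pooler head:
#
#     dense = get_linear_layer(1024, 1024, init_method_normal(0.02))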


@jit_fuser
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def openai_gelu(x):
    """GELU activation using the tanh approximation in gelu_impl."""
    return gelu_impl(x)


# This is actually the Python equivalent of torch.nn.functional.gelu().
@jit_fuser
def erf_gelu(x):
    # Exact GELU: x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2))); 1.41421 approximates sqrt(2).
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))


def get_norm(config):
    """Build the normalization module (LayerNorm or RMSNorm) selected by args.normalization."""
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)
    elif args.normalization == "RMSNorm":
        if args.apply_layernorm_1p:
            raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')

        return RMSNorm(dim=config.hidden_size,
                       eps=config.layernorm_epsilon,
                       sequence_parallel=config.sequence_parallel)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")
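
# Illustrative sketch, not part of the upstream module: get_norm() reads the global
# args set up by megatron.training, so a hypothetical call site in a transformer
# layer looks roughly like:
#
#     norm = get_norm(config)              # LayerNorm or RMSNorm, per --normalization
#     hidden_states = norm(hidden_states)  # last dim must equal config.hidden_size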