# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Utilities for models."""

import math

import torch

from megatron.training import get_args
from megatron.legacy.model import LayerNorm, RMSNorm
from megatron.core.jit import jit_fuser

def init_method_normal(sigma):
    """Init method based on N(0, sigma)."""
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=sigma)

    return init_


def scaled_init_method_normal(sigma, num_layers):
    """Init method based on N(0, sigma/sqrt(2*num_layers)."""
    std = sigma / math.sqrt(2.0 * num_layers)
Neel Kant's avatar
Neel Kant committed
24

25
26
27
28
29
30
    def init_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return init_
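
# Illustrative sketch only (not used by the library): how the two factories
# above are typically paired in an MLP block, with hypothetical values
# sigma=0.02 and num_layers=24 standing in for the real arguments.
def _example_init_usage(hidden_size=1024, ffn_hidden_size=4096,
                        sigma=0.02, num_layers=24):
    dense_h_to_4h = torch.nn.Linear(hidden_size, ffn_hidden_size)
    dense_4h_to_h = torch.nn.Linear(ffn_hidden_size, hidden_size)
    # Input projection: N(0, sigma); output projection: depth-scaled
    # N(0, sigma / sqrt(2 * num_layers)).
    init_method_normal(sigma)(dense_h_to_4h.weight)
    scaled_init_method_normal(sigma, num_layers)(dense_4h_to_h.weight)
    return dense_h_to_4h, dense_4h_to_h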


def attention_mask_func(attention_scores, attention_mask):
    """Fill masked-out positions of the attention scores with -10000.0 (in place)."""
    attention_scores.masked_fill_(attention_mask, -10000.0)
    return attention_scores
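
# Illustrative sketch only: attention_mask is boolean, True marks positions to
# hide, and masked scores are pushed to -10000.0 so they vanish after softmax.
# A hypothetical causal (upper-triangular) mask for a short sequence:
def _example_causal_masking(seq_len=4):
    scores = torch.zeros(seq_len, seq_len)
    mask = torch.triu(torch.ones(seq_len, seq_len, dtype=torch.bool), diagonal=1)
    return attention_mask_func(scores, mask)  # modifies scores in place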


def get_linear_layer(rows, columns, init_method):
    """Simple linear layer with weight initialization."""
    layer = torch.nn.Linear(rows, columns)
    if get_args().perform_initialization:
        init_method(layer.weight)
    with torch.no_grad():
        layer.bias.zero_()
    return layer
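
# Illustrative sketch only (assumes Megatron's global args have already been
# parsed, since get_args() decides whether initialization is performed):
# a pooler-style square projection whose weight is drawn from N(0, 0.02).
def _example_pooler_projection(hidden_size=1024):
    return get_linear_layer(hidden_size, hidden_size, init_method_normal(0.02))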


@jit_fuser
def gelu_impl(x):
    """OpenAI's gelu implementation."""
    return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * x *
                                       (1.0 + 0.044715 * x * x)))


def openai_gelu(x):
    return gelu_impl(x)


# This is effectively the Python equivalent of torch.nn.functional.gelu(), written out explicitly for the ONNX exporter.
@jit_fuser
def erf_gelu(x):
    # 1.41421 approximates sqrt(2): gelu(x) = x * 0.5 * (1 + erf(x / sqrt(2))).
    return x * 0.5 * (torch.erf(x / 1.41421).to(dtype=x.dtype) + torch.ones_like(x).to(dtype=x.dtype))
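
# Quick numerical sanity check (illustrative only): both variants above
# approximate torch.nn.functional.gelu and should agree with it to roughly
# three decimal places over typical activation ranges.
def _example_gelu_check():
    x = torch.linspace(-3.0, 3.0, steps=61)
    ref = torch.nn.functional.gelu(x)
    assert torch.allclose(openai_gelu(x), ref, atol=1e-3)
    assert torch.allclose(erf_gelu(x), ref, atol=1e-3)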


def get_norm(config):
    """Return the normalization module (LayerNorm or RMSNorm) selected by args.normalization."""
    args = get_args()
    if args.normalization == "LayerNorm":
        return LayerNorm(
            config.hidden_size,
            eps=config.layernorm_epsilon,
            no_persist_layer_norm=not config.persist_layer_norm,
            sequence_parallel=config.sequence_parallel,
            apply_layernorm_1p=args.apply_layernorm_1p)
    elif args.normalization == "RMSNorm":
        if args.apply_layernorm_1p:
            raise NotImplementedError('RMSNorm does not currently support the layernorm_1p formulation.')

        return RMSNorm(dim=config.hidden_size,
                       eps=config.layernorm_epsilon,
                       sequence_parallel=config.sequence_parallel)
    else:
        raise Exception(f"unsupported norm type '{args.normalization}'.")
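

# Illustrative sketch only (assumes Megatron's global args have been parsed and
# that `config` exposes hidden_size, layernorm_epsilon, persist_layer_norm and
# sequence_parallel, as a transformer config object does):
def _example_build_and_apply_norm(config, hidden_states):
    norm = get_norm(config)      # LayerNorm or RMSNorm, per args.normalization
    return norm(hidden_states)   # hidden_states: [sequence, batch, hidden]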