# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import math
from typing import Callable, Optional
import torch
import transformer_engine.pytorch as te
from transformer_engine.pytorch.fp8 import DelayedScaling, dist_group_type

def speedometer(
        module: torch.nn.Module,
        input: torch.Tensor,
        output_grad: torch.Tensor,
        forward_kwargs: Optional[dict] = None,
        fp8_autocast_kwargs: Optional[dict] = None,
        timing_iters: int = 50,
        warmup_iters: int = 50,
) -> None:
    """Measure average run time for a PyTorch module

    Performs forward and backward passes and prints the mean time per
    iteration (one forward plus one backward) in milliseconds.

    Parameters
    ----------
    module:
        Module to benchmark; called as ``module(input, **forward_kwargs)``.
    input:
        Tensor handed to the module on every iteration.
    output_grad:
        Gradient passed to ``output.backward`` on every iteration.
    forward_kwargs:
        Extra keyword arguments for the module call. ``None`` (the default)
        means no extra arguments.
    fp8_autocast_kwargs:
        Keyword arguments for ``te.fp8_autocast``. ``None`` (the default)
        runs with ``{"enabled": False}``.
    timing_iters:
        Number of timed iterations. Must be non-zero, since the elapsed
        time is divided by it.
    warmup_iters:
        Number of untimed iterations executed before timing starts.
    """
    # Resolve the optional dictionaries here instead of using a mutable
    # default argument, which would be shared between calls.
    if forward_kwargs is None:
        forward_kwargs = {}
    if fp8_autocast_kwargs is None:
        fp8_autocast_kwargs = {"enabled": False}

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Warmup runs
    torch.cuda.synchronize()
    for _ in range(warmup_iters):
        with te.fp8_autocast(**fp8_autocast_kwargs):
            output = module(input, **forward_kwargs)
        output.backward(output_grad)

    # Timing runs
    start.record()
    for _ in range(timing_iters):
        with te.fp8_autocast(**fp8_autocast_kwargs):
            output = module(input, **forward_kwargs)
        output.backward(output_grad)
    end.record()
    # Wait for the recorded work to finish before reading the events.
    torch.cuda.synchronize()

    print(f"Mean time: {start.elapsed_time(end)/timing_iters} ms")


class DotProductAttention(torch.nn.Module):
    """Scaled dot-product attention core of a Transformer layer

    Assembled from stock PyTorch operations only.

    """
    def __init__(
            self,
            num_attention_heads: int,
            kv_channels: int,
            attention_dropout: float,
    ) -> None:
        super().__init__()
        self.hidden_size_per_attention_head = kv_channels
        self.projection_size = kv_channels * num_attention_heads
        # Scores are divided by sqrt(per-head channel count).
        self.norm_factor = math.sqrt(kv_channels)
        self.dropout = torch.nn.Dropout(attention_dropout)

    def masked_softmax(
            self,
            inp: torch.Tensor,
            mask: Optional[torch.Tensor]
    ) -> torch.Tensor:
        """Softmax over the last dimension, with optional masking

        When a mask is given, the masked positions of ``inp`` are
        overwritten IN PLACE with -10000.0 before the softmax is taken.
        """
        # masked_fill_ returns the (modified) input tensor itself.
        scores = inp if mask is None else inp.masked_fill_(mask, -10000.0)
        return torch.nn.functional.softmax(scores, dim=-1)

    def forward(
            self,
            query: torch.Tensor,
            key: torch.Tensor,
            value: torch.Tensor,
            attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend ``query`` over ``key``/``value``

        Inputs are indexed as [sequence, batch, heads, channels]; the result
        is [query sequence, batch, projection_size].
        """
        seq_q, seq_k = query.size(0), key.size(0)
        batch, heads = query.size(1), query.size(2)
        head_dim = value.size(3)
        flat = batch * heads

        # [sq, b, np, hn] -> [b * np, sq, hn]
        q_bmm = query.view(seq_q, flat, -1).transpose(0, 1)
        # [sk, b, np, hn] -> [b * np, hn, sk]
        k_bmm = key.view(seq_k, flat, -1).permute(1, 2, 0)

        # raw scores, [b * np, sq, sk] -> [b, np, sq, sk]
        scores = torch.bmm(q_bmm, k_bmm) / self.norm_factor
        scores = scores.view(batch, heads, seq_q, seq_k)

        probs = self.dropout(self.masked_softmax(scores, attention_mask))

        # [sk, b, np, hn] -> [b * np, sk, hn]
        v_bmm = value.view(seq_k, flat, -1).transpose(0, 1)

        # weighted sum of values: [b * np, sq, hn]
        mixed = torch.bmm(probs.view(flat, seq_q, -1), v_bmm)

        # [b * np, sq, hn] -> [b, np, sq, hn] -> [sq, b, np, hn]
        mixed = mixed.view(batch, heads, seq_q, head_dim)
        mixed = mixed.permute(2, 0, 1, 3).contiguous()

        # merge heads: [sq, b, np, hn] -> [sq, b, hp]
        return mixed.view(seq_q, batch, self.projection_size)


class BasicMLP(torch.nn.Module):
    """Position-wise feed-forward block of a Transformer layer

    Two biased linear layers with a tanh-approximated GELU in between,
    assembled from stock PyTorch modules.

    """
    def __init__(
            self,
            hidden_size: int,
            ffn_hidden_size: int,
    ) -> None:
        super().__init__()
        # Expansion: hidden_size -> ffn_hidden_size
        self.linear1 = torch.nn.Linear(
            in_features=hidden_size, out_features=ffn_hidden_size, bias=True
        )
        # Contraction: ffn_hidden_size -> hidden_size
        self.linear2 = torch.nn.Linear(
            in_features=ffn_hidden_size, out_features=hidden_size, bias=True
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply linear1, GELU (tanh approximation), then linear2"""
        activated = torch.nn.functional.gelu(self.linear1(x), approximate='tanh')
        return self.linear2(activated)


def share_parameters_with_basic_te_model(te_model, basic_model):
    """Point a TE layer built from basic modules at the reference weights

    Each ``weight``/``bias`` attribute of the TE model is rebound to the
    very same object held by the pure PyTorch implementation, so the two
    models share parameters afterwards (nothing is cloned).

    """
    # Sub-modules that sit directly on the layer, in binding order.
    for name in ("ln1", "qkv_projection", "projection", "ln2"):
        target, source = getattr(te_model, name), getattr(basic_model, name)
        target.weight = source.weight
        target.bias = source.bias

    # The two linear layers nested inside the MLP.
    for name in ("linear1", "linear2"):
        target, source = getattr(te_model.mlp, name), getattr(basic_model.mlp, name)
        target.weight = source.weight
        target.bias = source.bias


def share_parameters_with_fused_te_model(te_model, basic_model):
    """Point a TE layer built from fused modules at the reference weights

    The fused modules expose the layer-norm and linear parameters under
    their own attribute names; each of them is rebound to the very same
    object held by the pure PyTorch implementation (nothing is cloned).

    """
    fused_qkv = te_model.ln_qkv
    fused_mlp = te_model.ln_mlp
    out_proj = te_model.projection
    ffn = basic_model.mlp

    # (target module, attribute on the target, object to bind), in order.
    bindings = (
        (fused_qkv, "layer_norm_weight", basic_model.ln1.weight),
        (fused_qkv, "layer_norm_bias", basic_model.ln1.bias),
        (fused_qkv, "weight", basic_model.qkv_projection.weight),
        (fused_qkv, "bias", basic_model.qkv_projection.bias),
        (out_proj, "weight", basic_model.projection.weight),
        (out_proj, "bias", basic_model.projection.bias),
        (fused_mlp, "layer_norm_weight", basic_model.ln2.weight),
        (fused_mlp, "layer_norm_bias", basic_model.ln2.bias),
        (fused_mlp, "fc1_weight", ffn.linear1.weight),
        (fused_mlp, "fc1_bias", ffn.linear1.bias),
        (fused_mlp, "fc2_weight", ffn.linear2.weight),
        (fused_mlp, "fc2_bias", ffn.linear2.bias),
    )
    for target, attr, shared in bindings:
        setattr(target, attr, shared)


def share_parameters_with_transformerlayer_te_model(te_model, basic_model):
    """Point a monolithic TE Transformer layer at the reference weights

    Every layer-norm and linear parameter of the TE layer is rebound to the
    very same object held by the pure PyTorch implementation (nothing is
    cloned).

    """
    attention = te_model.self_attention
    ln_qkv = attention.layernorm_qkv
    out_proj = attention.proj
    ln_mlp = te_model.layernorm_mlp
    ffn = basic_model.mlp

    # Attention: input layer norm fused with the QKV projection
    ln_qkv.layer_norm_weight = basic_model.ln1.weight
    ln_qkv.layer_norm_bias = basic_model.ln1.bias
    ln_qkv.weight = basic_model.qkv_projection.weight
    ln_qkv.bias = basic_model.qkv_projection.bias

    # Attention: output projection
    out_proj.weight = basic_model.projection.weight
    out_proj.bias = basic_model.projection.bias

    # MLP: second layer norm fused with both linear layers
    ln_mlp.layer_norm_weight = basic_model.ln2.weight
    ln_mlp.layer_norm_bias = basic_model.ln2.bias
    ln_mlp.fc1_weight = ffn.linear1.weight
    ln_mlp.fc1_bias = ffn.linear1.bias
    ln_mlp.fc2_weight = ffn.linear2.weight
    ln_mlp.fc2_bias = ffn.linear2.bias


def cast_to_representable(inp, scale = 1., fp8_format='e4m3'):
    """Round-trip a tensor through FP8 and back to its original dtype

    ``inp`` is cast to the selected FP8 type and then cast back to the TE
    dtype matching ``inp.dtype``; the round-tripped tensor is returned.
    ``fp8_format`` selects E4M3 when it equals ``'e4m3'``; any other value
    selects E5M2. ``scale`` is stored in the FP8 metadata and its
    reciprocal as the inverse scale.
    """
    import transformer_engine.pytorch.cpp_extensions as texcpp
    import transformer_engine_extensions as tex
    from transformer_engine.pytorch.constants import TE_DType

    if fp8_format == 'e4m3':
        fp8_type = tex.DType.kFloat8E4M3
    else:
        fp8_type = tex.DType.kFloat8E5M2
    source_dtype = TE_DType[inp.dtype]

    # Every metadata tensor is float32 and is created on the CUDA device.
    fp32_cuda = {"dtype": torch.float32, "device": "cuda"}
    meta = tex.FP8TensorMeta()
    meta.scale = torch.ones(1, **fp32_cuda) * scale
    meta.scale_inv = torch.ones(1, **fp32_cuda) / scale
    meta.amax_history = torch.zeros(1, 1, **fp32_cuda)

    # Both casts read the same metadata slot.
    slot = tex.FP8FwdTensors.GEMM1_INPUT
    quantized = texcpp.cast_to_fp8(inp, meta, slot, fp8_type)
    return texcpp.cast_from_fp8(quantized, meta, slot, fp8_type, source_dtype)