# test_linear_cross_entropy.py
#
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os

import torch

import verl.utils.torch_functional as verl_F
from verl.utils.experimental.torch_functional import FusedLinearForPPO
from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy
from verl.utils.torch_functional import logprobs_from_logits

# Compiled wrapper around verl's entropy_from_logits, built once at import time.
# dynamic=True is passed to torch.compile; presumably this avoids re-specializing
# for every distinct token count used by the test cases below.
compute_entropy_from_logits = torch.compile(verl_F.entropy_from_logits, dynamic=True)
# Single shared FusedLinearForPPO instance used by run_verl_torch_fused_entropy.
# It is compiled here, at import time, also with dynamic=True.
fused_linear_for_ppo = FusedLinearForPPO()
fused_linear_for_ppo.compile(dynamic=True)

# Number of test cases the __main__ driver runs (indices 0 .. MAX_TEST_CASES - 1).
# os.environ.get returns a str whenever the variable is set, so coerce to int;
# otherwise range(MAX_TEST_CASES) and the `<= 5` sanity check raise TypeError.
MAX_TEST_CASES = int(os.environ.get("MAX_TEST_CASES", 5))


def run_torch_entropy(
    hidden: torch.Tensor, weight: torch.Tensor, labels: torch.Tensor, temperature: float, reduction="none"
) -> tuple[torch.Tensor, torch.Tensor]:
    """Plain-PyTorch reference for per-token log-probabilities and entropy.

    Args:
        hidden: Activations; the leading dimension is squeezed away, so the
            callers in this file pass ``[1, num_tokens, hidden_size]``.
        weight: Projection matrix laid out as ``[vocab_size, hidden_size]``
            (it is transposed before the matmul).
        labels: Target token ids; the leading dimension is squeezed like ``hidden``.
        temperature: Divisor applied to the logits before softmax.
        reduction: Forwarded to ``torch.nn.functional.cross_entropy``.

    Returns:
        A ``(logprobs, entropy)`` tuple. ``logprobs`` is the negated
        cross-entropy of the labels; ``entropy`` is the per-token entropy of
        the softmax distribution. Both are computed in float32.
    """
    hidden = hidden.squeeze(0).to(torch.float32)
    weight = weight.transpose(0, 1).to(torch.float32)
    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]
    # In-place is safe: ``logits`` is a fresh matmul result owned by this function.
    logits /= temperature

    # H(p) = logsumexp(z) - sum_i p_i * z_i, with p = softmax(z).
    probs = torch.nn.functional.softmax(logits, dim=-1)  # [num_tokens, vocab_size]
    log_partition = torch.logsumexp(logits, dim=-1)  # [num_tokens]
    expected_logit = torch.sum(probs * logits, dim=-1)  # [num_tokens]
    entropy = log_partition - expected_logit

    # cross_entropy yields -log p(label); negate to obtain the log-probability.
    nll = torch.nn.functional.cross_entropy(logits, labels.squeeze(0), reduction=reduction)
    logprobs = torch.neg(nll)
    return logprobs, entropy


def run_verl_original_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
) -> tuple[torch.Tensor, torch.Tensor]:
    """Compute log-probabilities and entropy with verl's unfused helpers.

    The logits are materialised exactly as in ``run_torch_entropy`` (float32,
    leading dimension of ``hidden`` squeezed, ``weight`` transposed, divided
    by ``temperature``) and then handed to the compiled
    ``entropy_from_logits`` and to ``logprobs_from_logits``.

    Args:
        hidden: Activations; the leading dimension is squeezed away.
        weight: Projection matrix laid out as ``[vocab_size, hidden_size]``.
        labels: Target token ids, forwarded to ``logprobs_from_logits`` as given
            (not squeezed here).
        temperature: Divisor applied to the logits.

    Returns:
        A ``(logprobs, entropy)`` tuple.
    """
    hidden = hidden.squeeze(0).to(torch.float32)
    weight = weight.transpose(0, 1).to(torch.float32)
    logits = torch.matmul(hidden, weight)  # [num_tokens, vocab_size]
    # In-place is safe: ``logits`` is a fresh matmul result owned by this function.
    logits /= temperature
    entropy = compute_entropy_from_logits(logits)
    logprobs = logprobs_from_logits(logits=logits, labels=labels, inplace_backward=False)
    return logprobs, entropy


# To be tested
def run_verl_torch_fused_entropy(
    hidden: torch.Tensor,
    weight: torch.Tensor,
    labels: torch.Tensor,
    temperature: float,
):
    """Compute log-probabilities and entropy through the compiled ``FusedLinearForPPO`` module.

    ``hidden`` and ``weight`` are cast to float32 and passed, untransposed and
    unsqueezed, to the module-level ``fused_linear_for_ppo`` instance together
    with ``labels`` and ``temperature``.

    Returns:
        ``(logprobs, entropy)`` with the leading dimension of each squeezed away.
    """
    hidden_fp32 = hidden.to(torch.float32)
    weight_fp32 = weight.to(torch.float32)
    fused_outputs = fused_linear_for_ppo(hidden_fp32, weight_fp32, labels, temperature=temperature)
    logprobs, entropy = fused_outputs
    return logprobs.squeeze(0), entropy.squeeze(0)


class TestLinearCrossEntropy:
    """Correctness, latency and peak-memory harness for linear cross-entropy implementations.

    Four implementations are exercised on identical random inputs:
    ``run_torch_entropy``, ``run_verl_original_entropy``,
    ``run_verl_torch_fused_entropy`` and the ``linear_cross_entropy`` kernel.
    All tensors are created on the ``"cuda"`` device, so a CUDA device is required.

    Args:
        test_case_idx: Key into the predefined problem sizes (0-4).
        temperature: Logit temperature forwarded to every implementation.
    """

    # This is a manually driven harness with an ``__init__``, not a pytest test class.
    # Opting out of collection avoids pytest's "cannot collect test class" warning.
    __test__ = False

    # test_case_idx -> (batch_size, num_tokens, hidden_size, vocab_size)
    _HYPER_PARAMS = {
        0: (1, 1937, 3584, 152064),
        1: (1, 2169, 896, 151936),
        2: (1, 1530, 2048, 32256),
        3: (1, 1388, 4096, 102400),
        4: (1, 8192, 4096, 102400),
    }

    # (result key, label used in the forward message, label used in the backward message)
    _IMPL_LABELS = (
        ("torch", "Torch", "torch"),
        ("verl", "VeRL", "VeRL"),
        ("verl_fused", "VeRL Fused Entropy", "VeRL Fused Entropy"),
        ("kernel", "Kernel", "kernel"),
    )

    # (actual, expected) pairs compared with the tight tolerances.
    _REFERENCE_PAIRS = (("torch", "verl"), ("torch", "verl_fused"), ("verl", "verl_fused"))
    # (actual, expected) pairs compared with the looser kernel tolerances.
    _KERNEL_PAIRS = (("torch", "kernel"), ("verl", "kernel"), ("verl_fused", "kernel"))

    def __init__(self, test_case_idx: int, temperature: float = 1.5) -> None:
        self.test_case_idx = test_case_idx
        self.temperature = temperature

    def cleanup(self):
        """Release cached CUDA memory and reset the peak-memory counters."""
        import gc

        # Collect garbage first so tensors kept alive only by reference cycles
        # give their blocks back before the cache is emptied.
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()

    def generate_hyper(self):
        """Set dtype and problem dimensions on ``self`` for ``self.test_case_idx``.

        Raises:
            ValueError: If ``test_case_idx`` is not one of the predefined cases.
            AssertionError: If ``MAX_TEST_CASES`` exceeds the number of predefined cases.
        """
        self.dtype = torch.bfloat16
        try:
            params = self._HYPER_PARAMS[self.test_case_idx]
        except (KeyError, TypeError):
            raise ValueError(f"Invalid test case index: {self.test_case_idx}") from None
        self.batch_size, self.num_tokens, self.hidden_size, self.vocab_size = params
        # MAX_TEST_CASES may be a str when it was taken from the environment.
        assert int(MAX_TEST_CASES) <= 5, "MAX_TEST_CASES should be less than or equal to 5."

    def generate_forward_inputs(self):
        """Create random ``(hidden, weight, labels)`` CUDA tensors for the current problem size.

        ``hidden`` and ``weight`` are drawn uniformly from [-0.5, 0.5) in
        ``self.dtype`` and require grad; ``labels`` are random ids in
        ``[0, vocab_size)``.
        """
        hidden = (
            torch.empty((self.batch_size, self.num_tokens, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        weight = (
            torch.empty((self.vocab_size, self.hidden_size), dtype=self.dtype, device="cuda")
            .uniform_(-0.5, 0.5)
            .requires_grad_()
        )
        labels = torch.randint(0, self.vocab_size, (self.batch_size, self.num_tokens), device="cuda")
        return hidden, weight, labels

    def generate_backward_inputs(self):
        """Create random upstream gradients ``(g_entropy, g_logprobs)``, each of shape ``(num_tokens,)``."""
        g_entropy = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-0.5, 0.5)
        g_logprobs = torch.empty((self.num_tokens,), dtype=self.dtype, device="cuda").uniform_(-1, 1)
        return g_entropy, g_logprobs

    @staticmethod
    def _timed(fn, *args, **kwargs):
        """Call ``fn`` and return ``(result, elapsed_ms)`` measured with CUDA events."""
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        start_event.record()
        result = fn(*args, **kwargs)
        end_event.record()
        # elapsed_time is only valid once both events have completed.
        torch.cuda.synchronize()
        return result, start_event.elapsed_time(end_event)

    @staticmethod
    def _format_mean_latency(samples):
        """Format the mean of ``samples`` as ``"X.XX ms"``, or a placeholder when there are none."""
        if not samples:
            return "n/a (no timed iterations after warm-up)"
        return f"{sum(samples) / len(samples):.2f} ms"

    @staticmethod
    def _assert_all_close(results, pairs, fields, tolerances):
        """Assert ``results[a][k]`` is close to ``results[b][k]`` for every pair ``(a, b)`` and field ``k``.

        ``tolerances[k]`` is the ``(atol, rtol)`` applied to field ``k``.
        """
        for actual, expected in pairs:
            for idx, field in enumerate(fields):
                atol, rtol = tolerances[idx]
                torch.testing.assert_close(
                    results[actual][idx],
                    results[expected][idx],
                    atol=atol,
                    rtol=rtol,
                    # Prefix the failure with the pair being compared, since the
                    # line number alone no longer identifies it.
                    msg=lambda m, a=actual, e=expected, f=field: f"{f}: {a} vs {e}\n{m}",
                )

    def verify_correctness(self, iterations=5):
        """Cross-check forward outputs and gradients of all implementations and report mean latencies.

        Each iteration draws fresh inputs, runs every forward pass, compares
        the outputs, then runs every backward pass and compares the gradients.
        The first iteration is treated as warm-up and excluded from the
        latency averages.

        Args:
            iterations: Number of check/timing rounds.

        Raises:
            AssertionError: If any pair of implementations disagrees beyond tolerance.
        """
        self.cleanup()
        self.generate_hyper()

        # Looked up at call time; these are module-level names, not class attributes.
        forward_impls = {
            "torch": run_torch_entropy,
            "verl": run_verl_original_entropy,
            "verl_fused": run_verl_torch_fused_entropy,
            "kernel": linear_cross_entropy,
        }
        forward_ms = {name: [] for name in forward_impls}
        backward_ms = {name: [] for name in forward_impls}

        for i in range(iterations):
            print(f"[INFO]: Iteration {i + 1} / {iterations}...", end="\r")
            hidden, weight, labels = self.generate_forward_inputs()

            # forward: each entry is (logprobs, entropy)
            outputs = {}
            for name, forward in forward_impls.items():
                outputs[name], elapsed = self._timed(forward, hidden, weight, labels, self.temperature)
                forward_ms[name].append(elapsed)

            forward_fields = ("logprobs", "entropy")
            self._assert_all_close(outputs, self._REFERENCE_PAIRS, forward_fields, ((1e-4, 1e-4), (1e-4, 1e-4)))
            self._assert_all_close(outputs, self._KERNEL_PAIRS, forward_fields, ((1e-3, 2e-4), (5e-3, 5e-4)))

            # backward: each entry is (d_hidden, d_weight)
            g_entropy, g_logprobs = self.generate_backward_inputs()
            grads = {}
            for name, (logprobs, entropy) in outputs.items():
                grads[name], elapsed = self._timed(
                    torch.autograd.grad,
                    (entropy, logprobs),
                    (hidden, weight),
                    (g_entropy, g_logprobs),
                    retain_graph=False,
                )
                backward_ms[name].append(elapsed)

            backward_fields = ("d_hidden", "d_weight")
            self._assert_all_close(grads, self._REFERENCE_PAIRS, backward_fields, ((1e-2, 1e-4), (1e-2, 1e-4)))
            self._assert_all_close(grads, self._KERNEL_PAIRS, backward_fields, ((2e-2, 4e-2), (2e-2, 4e-2)))

        print("\n[INFO]: Verified forward & backward correctness.")

        for name, forward_label, backward_label in self._IMPL_LABELS:
            # Drop the first sample of each series: it is the warm-up iteration.
            forward_mean = self._format_mean_latency(forward_ms[name][1:])
            backward_mean = self._format_mean_latency(backward_ms[name][1:])
            print(f"[INFO]: Forward pass: {forward_label} implementation average time: {forward_mean}")
            print(f"[INFO]: Backward pass: {backward_label} implementation average time: {backward_mean}")

    def check_storage(self, method_name, run_forward):
        """Print the peak CUDA memory allocated during one forward and one backward pass.

        Args:
            method_name: Label used in the printed report.
            run_forward: Callable taking ``(hidden, weight, labels, temperature)``
                and returning ``(logprobs, entropy)``.
        """
        self.cleanup()
        self.generate_hyper()

        hidden, weight, labels = self.generate_forward_inputs()

        # Reset after input creation so the inputs' allocation is not counted as a new peak.
        torch.cuda.reset_peak_memory_stats()
        (logprobs, entropy) = run_forward(hidden, weight, labels, self.temperature)
        torch.cuda.synchronize()
        forward_peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Forward pass peak memory: {forward_peak_mb:.2f} MB")

        g_entropy, g_logprobs = self.generate_backward_inputs()

        torch.cuda.reset_peak_memory_stats()
        torch.autograd.grad((entropy, logprobs), (hidden, weight), (g_entropy, g_logprobs), retain_graph=False)
        torch.cuda.synchronize()
        backward_peak_mb = torch.cuda.max_memory_allocated() / 1024 / 1024
        print(f"[INFO]: {method_name} Backward pass peak memory: {backward_peak_mb:.2f} MB")

    def check_storage_all(self):
        """Run ``check_storage`` for every implementation under test."""
        self.check_storage("Torch", run_torch_entropy)
        self.check_storage("VeRL", run_verl_original_entropy)
        self.check_storage("VeRL Torch Fused", run_verl_torch_fused_entropy)
        self.check_storage("Kernel", linear_cross_entropy)


if __name__ == "__main__":
    # Script entry point: run the correctness/latency check and the peak-memory
    # report for each predefined test case in turn.
    # torch.cuda.memory._record_memory_history()

    # MAX_TEST_CASES may be a str when it was taken from the environment;
    # range() requires an int.
    for test_case_idx in range(int(MAX_TEST_CASES)):
        print(f"[INFO] Running test case {test_case_idx}")
        test = TestLinearCrossEntropy(test_case_idx)

        test.verify_correctness()
        test.check_storage_all()

    # torch.cuda.memory._dump_snapshot("test_linear_cross_entropy.pkl")