test_grpo_loss.py 43.6 KB
Newer Older
cmx's avatar
cmx committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
import pytest
import torch
import torch.nn.functional as F

from test.utils import assert_verbose_allclose
from test.utils import infer_device
from test.utils import set_seed

from liger_kernel.ops.grpo_loss import fused_selective_log_softmax
from liger_kernel.transformers.grpo_loss import triton_grpo_loss


@torch.no_grad
def selective_log_softmax(logits, input_ids, temperature=0.9):
    logits = logits[:, :-1, :]  # (B, L-1, V), exclude the last logit: it corresponds to the next token pred
    logits_to_keep = logits.size(1)
    index = input_ids[:, -logits_to_keep:]
    logits = logits[:, -logits_to_keep:]
    logits = logits / temperature

    if logits.dtype in [torch.float32, torch.float64]:
        selected_logits = torch.gather(logits, dim=-1, index=index.unsqueeze(-1)).squeeze(-1)
        # loop to reduce peak mem consumption
        logsumexp_values = torch.stack([torch.logsumexp(lg, dim=-1) for lg in logits])
        per_token_logps = selected_logits - logsumexp_values  # log_softmax(x_i) = x_i - logsumexp(x)
    else:
        # logsumexp approach is unstable with bfloat16, fall back to slightly less efficent approach
        per_token_logps = []
        for row_logits, row_labels in zip(logits, index):  # loop to reduce peak mem consumption
            row_logps = F.log_softmax(row_logits, dim=-1)
            row_per_token_logps = row_logps.gather(dim=-1, index=row_labels.unsqueeze(-1)).squeeze(-1)
            per_token_logps.append(row_per_token_logps)
        per_token_logps = torch.stack(per_token_logps)
    return per_token_logps


def _get_log_probs(logits, input_ids):
    """Helper function to compute per-token log probabilities."""
    per_token_logps = []
    for logits_row, input_ids_row in zip(logits, input_ids[:, -logits.size(1) :]):
        log_probs = logits_row.log_softmax(dim=-1)
        token_log_prob = torch.gather(log_probs, dim=1, index=input_ids_row.unsqueeze(1)).squeeze(1)
        per_token_logps.append(token_log_prob)
    return torch.stack(per_token_logps)


def torch_grpo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature,
    beta,
    eps_low,
    eps_high,
    delta=None,
    use_bias_correction_kl=False,
):
    """Per-token PyTorch reference for the clipped GRPO/PPO objective.

    Args:
        logits: Contiguous ``(B, L + 1, V)`` logits; the last position is dropped.
        old_logp: Log-probs of the sampling policy, or ``None`` to use the detached
            current log-probs (importance ratio of exactly 1).
        ref_logp: Reference-model log-probs; required (and contiguous) when ``beta != 0``.
        completion_ids: Contiguous ``(B, L)`` sampled token ids.
        advantages: Per-sequence advantages, broadcast over tokens via ``unsqueeze(1)``.
        completion_mask: Optional token mask multiplied into the loss and KL.
        temperature: Divisor applied to the logits before the softmax.
        beta: KL coefficient; ``0`` disables the KL term.
        eps_low, eps_high: Clip range ``[1 - eps_low, 1 + eps_high]`` for the ratio.
        delta: Optional upper cap applied to the unclipped ratio.
        use_bias_correction_kl: Multiply the KL by ``exp(logp - old_logp)``.

    Returns:
        ``(per_token_loss, per_token_kl, is_clipped)``. ``per_token_kl`` is ``None`` when
        ``beta == 0``. ``is_clipped`` is 1.0 where the unclipped term is strictly smaller
        than the clipped term.
    """
    assert logits.is_contiguous() and completion_ids.is_contiguous()
    assert old_logp is None or old_logp.is_contiguous()
    if beta != 0.0:
        assert ref_logp is not None and ref_logp.is_contiguous()

    logp = _get_log_probs(logits[:, :-1] / temperature, completion_ids)
    baseline = logp.detach() if old_logp is None else old_logp
    adv = advantages.unsqueeze(1)

    ratio = torch.exp(logp - baseline)
    clipped_ratio = torch.clamp(ratio, 1 - eps_low, 1 + eps_high)
    if delta is not None:
        # Two-sided clipping: additionally cap the unclipped ratio from above.
        ratio = torch.clamp(ratio, max=delta)
    unclipped_term = ratio * adv
    clipped_term = clipped_ratio * adv
    per_token_loss = -torch.min(unclipped_term, clipped_term)
    if completion_mask is not None:
        per_token_loss = per_token_loss * completion_mask

    per_token_kl = None
    if beta != 0.0:
        # k3 estimator: exp(ref - logp) - (ref - logp) - 1
        per_token_kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1
        if use_bias_correction_kl:
            per_token_kl = per_token_kl * torch.exp(logp - baseline)
        if completion_mask is not None:
            per_token_kl *= completion_mask
        per_token_loss = per_token_loss + beta * per_token_kl

    is_clipped = (unclipped_term < clipped_term).float()
    return per_token_loss, per_token_kl, is_clipped


def torch_cispo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature,
    beta,
    eps_high,
    use_bias_correction_kl=False,
):
    """Per-token PyTorch reference for the CISPO loss.

    CISPO (Clipped Importance Sampling Policy Optimization) differs from PPO-style
    clipping in three ways:
    - the importance ratio is clipped from above only (at ``eps_high``);
    - the clipped ratio is detached, so no gradient flows through the clipping;
    - the loss multiplies the detached weight by the token log-probability.

    Reference: MiniMax-M1 technical report.

    Returns:
        ``(per_token_loss, per_token_kl, is_clipped)``. ``per_token_kl`` is ``None`` when
        ``beta == 0``. ``is_clipped`` is 1.0 where the ratio exceeds ``eps_high`` and the
        advantage is positive.
    """
    assert logits.is_contiguous() and completion_ids.is_contiguous()
    assert old_logp is None or old_logp.is_contiguous()
    if beta != 0.0:
        assert ref_logp is not None and ref_logp.is_contiguous()

    logp = _get_log_probs(logits[:, :-1] / temperature, completion_ids)
    baseline = logp.detach() if old_logp is None else old_logp
    adv = advantages.unsqueeze(1)

    ratio = torch.exp(logp - baseline)
    # Upper-bound-only clip; detached so the weight itself carries no gradient.
    weight = torch.clamp(ratio, max=eps_high).detach()
    per_token_loss = -weight * adv * logp
    if completion_mask is not None:
        per_token_loss = per_token_loss * completion_mask

    per_token_kl = None
    if beta != 0.0:
        per_token_kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1
        if use_bias_correction_kl:
            per_token_kl = per_token_kl * torch.exp(logp - baseline)
        if completion_mask is not None:
            per_token_kl *= completion_mask
        per_token_loss = per_token_loss + beta * per_token_kl

    is_clipped = ((ratio > eps_high) & (adv > 0)).float()
    return per_token_loss, per_token_kl, is_clipped


def torch_sapo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature,
    beta,
    sapo_temperature_pos,
    sapo_temperature_neg,
    use_bias_correction_kl=False,
):
    """Per-token PyTorch reference for the SAPO loss.

    SAPO (Soft Adaptive Policy Optimization) replaces hard clipping with a sigmoid gate
    ``sigmoid(tau * (ratio - 1)) * 4 / tau``. ``tau`` is ``sapo_temperature_pos`` for
    positive advantages and ``sapo_temperature_neg`` otherwise.

    Reference: https://huggingface.co/papers/2511.20347

    Returns:
        ``(per_token_loss, per_token_kl, is_clipped)``. ``per_token_kl`` is ``None`` when
        ``beta == 0``. ``is_clipped`` is all zeros, since SAPO has no clipping.
    """
    assert logits.is_contiguous() and completion_ids.is_contiguous()
    assert old_logp is None or old_logp.is_contiguous()
    if beta != 0.0:
        assert ref_logp is not None and ref_logp.is_contiguous()

    logp = _get_log_probs(logits[:, :-1] / temperature, completion_ids)
    baseline = logp.detach() if old_logp is None else old_logp
    adv = advantages.unsqueeze(1)

    ratio = torch.exp(logp - baseline)
    # Advantage sign selects the gate temperature.
    tau = torch.where(adv > 0, sapo_temperature_pos, sapo_temperature_neg)
    gate = torch.sigmoid(tau * (ratio - 1.0)) * 4.0 / tau
    per_token_loss = -gate * adv
    if completion_mask is not None:
        per_token_loss = per_token_loss * completion_mask

    per_token_kl = None
    if beta != 0.0:
        per_token_kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1
        if use_bias_correction_kl:
            per_token_kl = per_token_kl * torch.exp(logp - baseline)
        if completion_mask is not None:
            per_token_kl *= completion_mask
        per_token_loss = per_token_loss + beta * per_token_kl

    is_clipped = torch.zeros_like(per_token_loss)
    return per_token_loss, per_token_kl, is_clipped


# Seed via the project helper at import time so parametrized cases draw reproducible inputs.
set_seed(42)
# Device shared by every test in this module, as chosen by test.utils.infer_device.
device = infer_device()


@pytest.mark.parametrize(
    "temperature, B, T, V",
    [
        (0.9, 1, 1024, 64000),
        (0.7, 1, 1024, 151936),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-2, 5e-1),
    ],
)
def test_selective_log_softmax(B, T, V, temperature, dtype, atol, rtol):
    """Low-precision torch and fused Triton log-softmax must both match an fp32 torch baseline."""
    # T + 1 positions: selective_log_softmax drops the final (next-token) logit.
    source = torch.randn(B, T + 1, V, device=device, dtype=dtype)
    torch_in = source.clone()
    triton_in = source.clone()
    fp32_in = source.clone().float()

    # Each sequence is 100 prompt tokens followed by T completion tokens.
    num_prompt_tokens = 100
    token_ids = torch.randint(0, V - 1, (B, num_prompt_tokens + T), dtype=torch.int64, device=device)

    torch_out = selective_log_softmax(torch_in, token_ids, temperature)
    triton_out = fused_selective_log_softmax(triton_in, token_ids, temperature)
    baseline = selective_log_softmax(fp32_in, token_ids, temperature).to(dtype)

    for candidate in (torch_out, triton_out):
        assert_verbose_allclose(candidate, baseline, rtol=rtol, atol=atol)


@pytest.mark.parametrize(
    "temperature, num_iteration, beta, eps_low, eps_high",
    [(0.7, num_iteration, beta, 0.2, 0.4) for num_iteration in [1, 5] for beta in [0.0, 0.04]],
)
@pytest.mark.parametrize(
    "B, T, V",
    [
        (1, 1024, 151936),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-2, 5e-1),
    ],
)
def test_grpo_loss(B, T, V, temperature, num_iteration, beta, eps_low, eps_high, dtype, atol, rtol):
    """Triton GRPO loss and a low-precision torch reference must both track an fp32 torch reference."""
    base = torch.randn(B, T + 1, V, device=device, dtype=dtype)
    logits_lp = base.clone().requires_grad_(True)
    logits_triton = base.clone().requires_grad_(True)
    logits_fp32 = base.clone().float().requires_grad_(True)

    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
    completion_mask[:, -100:] = 0  # treat the last 100 positions as padding

    # Kept in fp32: fused_selective_log_softmax returns fp32 log-probs even for bf16 logits.
    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32) if num_iteration > 1 else None
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    shared = (old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high)
    loss_lp, kl_lp, _ = torch_grpo_loss(logits_lp, *shared)
    loss_triton, kl_triton, _ = triton_grpo_loss(logits_triton, *shared, inplace=True)
    loss_fp32, kl_fp32, _ = torch_grpo_loss(logits_fp32, *shared)

    upstream = torch.randn_like(loss_fp32)
    for loss in (loss_lp, loss_triton, loss_fp32):
        loss.backward(upstream)

    for loss, kl, logits in ((loss_lp, kl_lp, logits_lp), (loss_triton, kl_triton, logits_triton)):
        assert_verbose_allclose(loss, loss_fp32, atol=atol, rtol=rtol)
        if kl is not None and kl_fp32 is not None:
            assert_verbose_allclose(kl, kl_fp32, atol=atol, rtol=rtol)
        assert_verbose_allclose(logits.grad, logits_fp32.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize("delta", [1.5, 2.0])
@pytest.mark.parametrize(
    "temperature, num_iteration, beta, eps_low, eps_high",
    [(0.7, 5, beta, 0.2, 0.4) for beta in [0.0, 0.04]],
)
@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 128, 1000),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-2, 5e-1),
    ],
)
def test_grpo_loss_with_delta(B, T, V, temperature, num_iteration, beta, eps_low, eps_high, dtype, atol, rtol, delta):
    """Test delta (two-sided clipping) support for standard PPO loss types."""
    base = torch.randn(B, T + 1, V, device=device, dtype=dtype)
    logits_lp = base.clone().requires_grad_(True)
    logits_triton = base.clone().requires_grad_(True)
    logits_fp32 = base.clone().float().requires_grad_(True)

    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
    completion_mask[:, -20:] = 0  # treat the last 20 positions as padding

    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32) if num_iteration > 1 else None
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    shared = (old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high)
    loss_lp, kl_lp, _ = torch_grpo_loss(logits_lp, *shared, delta=delta)
    loss_triton, kl_triton, _ = triton_grpo_loss(logits_triton, *shared, inplace=True, delta=delta)
    loss_fp32, kl_fp32, _ = torch_grpo_loss(logits_fp32, *shared, delta=delta)

    upstream = torch.randn_like(loss_fp32)
    for loss in (loss_lp, loss_triton, loss_fp32):
        loss.backward(upstream)

    for loss, kl, logits in ((loss_lp, kl_lp, logits_lp), (loss_triton, kl_triton, logits_triton)):
        assert_verbose_allclose(loss, loss_fp32, atol=atol, rtol=rtol)
        if kl is not None and kl_fp32 is not None:
            assert_verbose_allclose(kl, kl_fp32, atol=atol, rtol=rtol)
        assert_verbose_allclose(logits.grad, logits_fp32.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
    "temperature, num_iteration, eps_low, eps_high",
    [(0.7, 5, 0.2, 0.4)],
)
@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 128, 1000),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-2, 5e-1),
    ],
)
def test_grpo_loss_with_bias_correction_kl(B, T, V, temperature, num_iteration, eps_low, eps_high, dtype, atol, rtol):
    """Test use_bias_correction_kl (importance-sampling-corrected KL from DeepSeek-V3.2)."""
    beta = 0.04  # the KL term only contributes when beta is non-zero
    base = torch.randn(B, T + 1, V, device=device, dtype=dtype)
    logits_lp = base.clone().requires_grad_(True)
    logits_triton = base.clone().requires_grad_(True)
    logits_fp32 = base.clone().float().requires_grad_(True)

    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
    completion_mask[:, -20:] = 0  # treat the last 20 positions as padding

    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32)
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32)
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    shared = (old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_low, eps_high)
    loss_lp, kl_lp, _ = torch_grpo_loss(logits_lp, *shared, use_bias_correction_kl=True)
    loss_triton, kl_triton, _ = triton_grpo_loss(logits_triton, *shared, inplace=True, use_bias_correction_kl=True)
    loss_fp32, kl_fp32, _ = torch_grpo_loss(logits_fp32, *shared, use_bias_correction_kl=True)

    upstream = torch.randn_like(loss_fp32)
    for loss in (loss_lp, loss_triton, loss_fp32):
        loss.backward(upstream)

    for loss, kl, logits in ((loss_lp, kl_lp, logits_lp), (loss_triton, kl_triton, logits_triton)):
        assert_verbose_allclose(loss, loss_fp32, atol=atol, rtol=rtol)
        if kl is not None and kl_fp32 is not None:
            assert_verbose_allclose(kl, kl_fp32, atol=atol, rtol=rtol)
        assert_verbose_allclose(logits.grad, logits_fp32.grad, atol=atol, rtol=rtol)


def trl_reference_grpo_loss(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature,
    beta,
    eps_low,
    eps_high,
    loss_type,
    importance_sampling_level,
    delta=None,
    use_bias_correction_kl=False,
):
    """TRL reference implementation from grpo_trainer.py, reduced to a scalar loss.

    Args:
        logits: ``(B, L + 1, V)`` policy logits; the last position is dropped.
        old_logp: ``(B, L)`` sampling-policy log-probs, or ``None`` to use the detached
            current log-probs (importance ratio of exactly 1).
        ref_logp: ``(B, L)`` reference-model log-probs; only read when ``beta != 0``.
        completion_ids: ``(B, L)`` sampled token ids.
        advantages: Per-sequence advantages, broadcast over tokens via ``unsqueeze(-1)``.
        completion_mask: ``(B, L)`` mask of valid completion tokens.
        temperature: Divisor applied to the logits before the softmax.
        beta: KL coefficient; ``0`` disables the KL term.
        eps_low, eps_high: Clip range ``[1 - eps_low, 1 + eps_high]`` for the ratio.
        loss_type: One of ``"grpo"``, ``"bnpo"``, ``"dr_grpo"``, ``"dapo"``, ``"luspo"``.
        importance_sampling_level: ``"token"`` for per-token ratios, or ``"sequence"`` for
            one ratio per sequence from the masked mean log-ratio.
        delta: Optional upper cap applied to the unclipped ratio.
        use_bias_correction_kl: Multiply the KL by ``exp(logp - old_logp)``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``loss_type`` or ``importance_sampling_level`` is not recognised.
    """
    # Validate up front. An unknown loss_type would otherwise skip every reduction
    # branch and surface only as an unbound ``loss`` after all the work is done.
    if loss_type not in ("grpo", "bnpo", "dr_grpo", "dapo", "luspo"):
        raise ValueError(f"unsupported loss_type: {loss_type!r}")
    if importance_sampling_level not in ("token", "sequence"):
        raise ValueError(f"unsupported importance_sampling_level: {importance_sampling_level!r}")

    B, L_plus_1, _ = logits.shape
    L = L_plus_1 - 1

    logits_scaled = logits[:, :-1, :] / temperature
    log_probs = torch.log_softmax(logits_scaled.float(), dim=-1)
    per_token_logps = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)

    if old_logp is None:
        old_logp = per_token_logps.detach()

    log_ratio = per_token_logps - old_logp

    if importance_sampling_level == "token":
        log_importance_weights = log_ratio
    else:  # "sequence": one masked-mean log-ratio per sequence, shape (B, 1)
        log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
        log_importance_weights = log_importance_weights.unsqueeze(-1)

    coef_1 = torch.exp(log_importance_weights)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    if delta is not None:
        coef_1 = torch.clamp(coef_1, max=delta)

    per_token_loss1 = coef_1 * advantages.unsqueeze(-1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(-1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

    if importance_sampling_level == "sequence":
        # Broadcast the per-sequence loss back to every token position.
        per_token_loss = per_token_loss.expand(B, L)

    if beta != 0.0:
        kl = torch.exp(ref_logp - per_token_logps) - (ref_logp - per_token_logps) - 1.0
        if use_bias_correction_kl:
            kl = kl * torch.exp(per_token_logps - old_logp)
        per_token_loss = per_token_loss + beta * kl

    # Loss reduction
    if loss_type == "grpo":
        # Per-sequence masked mean, then mean over the batch.
        loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
    elif loss_type in ("bnpo", "dapo"):
        # Both normalize by the total number of valid tokens in this reference.
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
    elif loss_type == "dr_grpo":
        # Constant normalizer: batch size times completion length.
        loss = (per_token_loss * completion_mask).sum() / (B * L)
    else:  # "luspo" (validated above)
        loss = (per_token_loss * completion_mask.sum(-1, keepdim=True)).mean()

    return loss


@pytest.mark.parametrize("delta", [None, 1.5])
@pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"])
@pytest.mark.parametrize("loss_type", ["grpo", "bnpo", "dr_grpo", "dapo", "luspo"])
@pytest.mark.parametrize("beta", [0.0, 0.04])
@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 128, 1000),
    ],
)
def test_grpo_loss_vs_trl(B, T, V, beta, loss_type, importance_sampling_level, delta):
    """Test that triton_grpo_loss matches TRL's exact implementation."""
    torch.manual_seed(42)
    temperature = 0.9
    eps_low, eps_high = 0.2, 0.4

    logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)
    completion_ids = torch.randint(0, V, (B, T), device=device)
    completion_mask = torch.ones(B, T, device=device, dtype=torch.float32)
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    # old/ref log-probs are noisy perturbations of the current policy so ratios stay realistic.
    with torch.no_grad():
        current_logp = (
            torch.log_softmax(logits[:, :-1, :] / temperature, dim=-1)
            .gather(-1, completion_ids.unsqueeze(-1))
            .squeeze(-1)
        )
        old_logp = current_logp + torch.randn_like(current_logp) * 0.3
        ref_logp = None
        if beta != 0.0:
            ref_logp = current_logp + torch.randn_like(current_logp) * 0.2

    expected = trl_reference_grpo_loss(
        logits.clone(),
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        loss_type,
        importance_sampling_level,
        delta=delta,
    )

    logits_triton = logits.clone().requires_grad_(True)
    actual, _ = triton_grpo_loss(
        logits_triton,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature=temperature,
        beta=beta,
        eps_low=eps_low,
        eps_high=eps_high,
        importance_sampling_level=importance_sampling_level,
        loss_type=loss_type,
        max_completion_length=T,
        reduce=True,
        delta=delta,
    )

    # Forward values must agree with the reference.
    torch.testing.assert_close(actual, expected, rtol=1e-4, atol=1e-4)

    # Backward must run and leave a NaN-free gradient on the logits.
    actual.backward()
    assert logits_triton.grad is not None
    assert not torch.isnan(logits_triton.grad).any()


def trl_reference_grpo_loss_with_vllm_is(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature,
    beta,
    eps_low,
    eps_high,
    loss_type,
    importance_sampling_level,
    vllm_is_ratio,
    delta=None,
    use_bias_correction_kl=False,
):
    """TRL reference implementation with vLLM IS ratio correction, reduced to a scalar.

    Same computation as the plain TRL reference, except that the clipped policy term is
    multiplied by ``vllm_is_ratio`` (when not ``None``) before the KL penalty is added.

    Args:
        logits: ``(B, L + 1, V)`` policy logits; the last position is dropped.
        old_logp: ``(B, L)`` sampling-policy log-probs, or ``None`` to use the detached
            current log-probs (importance ratio of exactly 1).
        ref_logp: ``(B, L)`` reference-model log-probs; only read when ``beta != 0``.
        completion_ids: ``(B, L)`` sampled token ids.
        advantages: Per-sequence advantages, broadcast over tokens via ``unsqueeze(-1)``.
        completion_mask: ``(B, L)`` mask of valid completion tokens.
        temperature: Divisor applied to the logits before the softmax.
        beta: KL coefficient; ``0`` disables the KL term.
        eps_low, eps_high: Clip range ``[1 - eps_low, 1 + eps_high]`` for the ratio.
        loss_type: One of ``"grpo"``, ``"bnpo"``, ``"dr_grpo"``, ``"dapo"``, ``"luspo"``.
        importance_sampling_level: ``"token"`` or ``"sequence"``.
        vllm_is_ratio: Per-token correction factor multiplied into the policy term, or ``None``.
        delta: Optional upper cap applied to the unclipped ratio.
        use_bias_correction_kl: Multiply the KL by ``exp(logp - old_logp)``.

    Returns:
        Scalar loss tensor.

    Raises:
        ValueError: If ``loss_type`` or ``importance_sampling_level`` is not recognised.
    """
    # Validate up front. An unknown loss_type would otherwise skip every reduction
    # branch and surface only as an unbound ``loss`` after all the work is done.
    if loss_type not in ("grpo", "bnpo", "dr_grpo", "dapo", "luspo"):
        raise ValueError(f"unsupported loss_type: {loss_type!r}")
    if importance_sampling_level not in ("token", "sequence"):
        raise ValueError(f"unsupported importance_sampling_level: {importance_sampling_level!r}")

    B, L_plus_1, _ = logits.shape
    L = L_plus_1 - 1

    logits_scaled = logits[:, :-1, :] / temperature
    log_probs = torch.log_softmax(logits_scaled.float(), dim=-1)
    per_token_logps = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)

    if old_logp is None:
        old_logp = per_token_logps.detach()

    log_ratio = per_token_logps - old_logp

    if importance_sampling_level == "token":
        log_importance_weights = log_ratio
    else:  # "sequence": one masked-mean log-ratio per sequence, shape (B, 1)
        log_importance_weights = (log_ratio * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)
        log_importance_weights = log_importance_weights.unsqueeze(-1)

    coef_1 = torch.exp(log_importance_weights)
    coef_2 = torch.clamp(coef_1, 1 - eps_low, 1 + eps_high)
    if delta is not None:
        coef_1 = torch.clamp(coef_1, max=delta)

    per_token_loss1 = coef_1 * advantages.unsqueeze(-1)
    per_token_loss2 = coef_2 * advantages.unsqueeze(-1)
    per_token_loss = -torch.min(per_token_loss1, per_token_loss2)

    if importance_sampling_level == "sequence":
        # Broadcast the per-sequence loss back to every token position.
        per_token_loss = per_token_loss.expand(B, L)

    # Apply vLLM IS ratio BEFORE KL penalty (matches TRL)
    if vllm_is_ratio is not None:
        per_token_loss = per_token_loss * vllm_is_ratio

    if beta != 0.0:
        kl = torch.exp(ref_logp - per_token_logps) - (ref_logp - per_token_logps) - 1.0
        if use_bias_correction_kl:
            kl = kl * torch.exp(per_token_logps - old_logp)
        per_token_loss = per_token_loss + beta * kl

    # Loss reduction
    if loss_type == "grpo":
        # Per-sequence masked mean, then mean over the batch.
        loss = ((per_token_loss * completion_mask).sum(-1) / completion_mask.sum(-1).clamp(min=1.0)).mean()
    elif loss_type in ("bnpo", "dapo"):
        # Both normalize by the total number of valid tokens in this reference.
        loss = (per_token_loss * completion_mask).sum() / completion_mask.sum().clamp(min=1.0)
    elif loss_type == "dr_grpo":
        # Constant normalizer: batch size times completion length.
        loss = (per_token_loss * completion_mask).sum() / (B * L)
    else:  # "luspo" (validated above)
        loss = (per_token_loss * completion_mask.sum(-1, keepdim=True)).mean()

    return loss


def torch_grpo_loss_with_vllm_is(
    logits,
    old_logp,
    ref_logp,
    completion_ids,
    advantages,
    completion_mask,
    temperature,
    beta,
    eps_low,
    eps_high,
    vllm_is_ratio,
    loss_type="grpo",
    sapo_temperature_pos=1.0,
    sapo_temperature_neg=1.05,
    delta=None,
    use_bias_correction_kl=False,
):
    """Per-token reference loss with optional vLLM importance-sampling correction.

    The policy term depends on ``loss_type``:
    - ``"cispo"``: detached upper-clipped ratio times advantage times log-prob.
    - ``"sapo"``: sigmoid-gated ratio with advantage-sign-dependent temperature.
    - anything else: PPO-style clipped objective, with an optional ``delta`` cap.

    The policy term is scaled by ``vllm_is_ratio`` (when given) and masked before the
    KL penalty is added.

    Returns:
        ``(per_token_loss, per_token_kl, is_clipped)``; ``per_token_kl`` is ``None`` when
        ``beta == 0``.
    """
    assert logits.is_contiguous() and completion_ids.is_contiguous()
    logp = _get_log_probs(logits[:, :-1] / temperature, completion_ids)
    baseline = logp.detach() if old_logp is None else old_logp
    adv = advantages.unsqueeze(1)
    ratio = torch.exp(logp - baseline)

    if loss_type == "cispo":
        weight = torch.clamp(ratio, max=eps_high).detach()
        per_token_loss = -weight * adv * logp
        is_clipped = ((ratio > eps_high) & (adv > 0)).float()
    elif loss_type == "sapo":
        tau = torch.where(adv > 0, sapo_temperature_pos, sapo_temperature_neg)
        gate = torch.sigmoid(tau * (ratio - 1.0)) * 4.0 / tau
        per_token_loss = -gate * adv
        is_clipped = torch.zeros_like(per_token_loss)
    else:
        clipped_ratio = torch.clamp(ratio, 1 - eps_low, 1 + eps_high)
        if delta is not None:
            ratio = torch.clamp(ratio, max=delta)
        unclipped_term = ratio * adv
        clipped_term = clipped_ratio * adv
        per_token_loss = -torch.min(unclipped_term, clipped_term)
        is_clipped = (unclipped_term < clipped_term).float()

    # The vLLM IS correction scales only the policy term; the KL is added afterwards.
    if vllm_is_ratio is not None:
        per_token_loss = per_token_loss * vllm_is_ratio
    if completion_mask is not None:
        per_token_loss = per_token_loss * completion_mask

    per_token_kl = None
    if beta != 0.0:
        per_token_kl = torch.exp(ref_logp - logp) - (ref_logp - logp) - 1
        if use_bias_correction_kl:
            per_token_kl = per_token_kl * torch.exp(logp - baseline)
        if completion_mask is not None:
            per_token_kl *= completion_mask
        per_token_loss = per_token_loss + beta * per_token_kl
    return per_token_loss, per_token_kl, is_clipped


@pytest.mark.parametrize("importance_sampling_level", ["token", "sequence"])
@pytest.mark.parametrize("loss_type", ["grpo", "dapo", "luspo"])
@pytest.mark.parametrize("beta", [0.0, 0.04])
@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 128, 1000),
    ],
)
def test_grpo_loss_with_vllm_is_ratio_reduced(B, T, V, beta, loss_type, importance_sampling_level):
    """Test that triton_grpo_loss with vllm_is_ratio matches TRL's behavior with reduce=True.

    Additionally checks that ``vllm_is_ratio=None`` and an all-ones ratio give the
    same reduced loss AND the same gradient with respect to the logits.
    """
    torch.manual_seed(42)

    # Defined before the synthetic log-probs so they are built with the same
    # temperature the losses under test use.
    temperature = 0.9
    eps_low, eps_high = 0.2, 0.4

    logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)
    completion_ids = torch.randint(0, V, (B, T), device=device)
    completion_mask = torch.ones(B, T, device=device, dtype=torch.float32)
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    # Compute realistic old_logp and ref_logp as small random perturbations of the
    # current policy's token log-probs.
    with torch.no_grad():
        log_probs = torch.log_softmax(logits[:, :-1, :] / temperature, dim=-1)
        current_logp = log_probs.gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
        old_logp = current_logp + torch.randn_like(current_logp) * 0.3
        ref_logp = current_logp + torch.randn_like(current_logp) * 0.2 if beta != 0.0 else None

    # Create vLLM IS ratio (random values in [0.5, 1.5))
    vllm_is_ratio = torch.rand(B, T, device=device, dtype=torch.float32) + 0.5

    # Keyword arguments shared by every triton_grpo_loss call in this test.
    triton_kwargs = dict(
        temperature=temperature,
        beta=beta,
        eps_low=eps_low,
        eps_high=eps_high,
        importance_sampling_level=importance_sampling_level,
        loss_type=loss_type,
        max_completion_length=T,
        reduce=True,
    )

    # TRL reference with vLLM IS ratio
    trl_loss = trl_reference_grpo_loss_with_vllm_is(
        logits.clone(),
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
        loss_type,
        importance_sampling_level,
        vllm_is_ratio,
    )

    # Triton implementation with vLLM IS ratio
    logits_triton = logits.clone().requires_grad_(True)
    triton_loss, _ = triton_grpo_loss(
        logits_triton,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        vllm_is_ratio=vllm_is_ratio,
        **triton_kwargs,
    )

    # Verify forward match
    torch.testing.assert_close(triton_loss, trl_loss, rtol=1e-4, atol=1e-4)

    # Verify backward works
    triton_loss.backward()
    assert logits_triton.grad is not None
    assert not torch.isnan(logits_triton.grad).any()

    # vllm_is_ratio=None must be equivalent to an all-ones ratio, for both the
    # reduced loss and the logits gradient.
    logits_no_ratio = logits.clone().requires_grad_(True)
    loss_no_ratio, _ = triton_grpo_loss(
        logits_no_ratio,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        vllm_is_ratio=None,
        **triton_kwargs,
    )

    logits_ones_ratio = logits.clone().requires_grad_(True)
    loss_ones_ratio, _ = triton_grpo_loss(
        logits_ones_ratio,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        vllm_is_ratio=torch.ones(B, T, device=device, dtype=torch.float32),
        **triton_kwargs,
    )

    loss_no_ratio.backward()
    loss_ones_ratio.backward()

    torch.testing.assert_close(loss_no_ratio, loss_ones_ratio, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(logits_no_ratio.grad, logits_ones_ratio.grad, rtol=1e-5, atol=1e-5)


@pytest.mark.parametrize(
    "temperature, num_iteration, beta, eps_low, eps_high",
    [(0.7, num_iteration, beta, 0.2, 0.4) for num_iteration in [1, 5] for beta in [0.0, 0.04]],
)
@pytest.mark.parametrize(
    "B, T, V",
    [
        (2, 128, 1000),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-2, 5e-1),
    ],
)
@pytest.mark.parametrize("loss_type", ["grpo", "cispo", "sapo"])
def test_grpo_loss_with_vllm_is_ratio(
    B, T, V, temperature, num_iteration, beta, eps_low, eps_high, dtype, atol, rtol, loss_type
):
    """Test that triton_grpo_loss with vllm_is_ratio matches PyTorch reference for all loss types.

    Checks, in order:
      * the torch reference and the triton kernel, both on ``dtype`` logits, against the
        torch reference on fp32 logits (loss, KL when present, logits gradient);
      * ``vllm_is_ratio=None`` against an all-ones ``(B, T)`` ratio;
      * a ``(B, 1)`` ratio against the equivalent uniform ``(B, T)`` ratio;
      * a 1-D ``(B,)`` ratio against the equivalent ``(B, 1)`` ratio.
    The last three run the triton kernel on fp32 logits and compare loss and gradient.
    """
    torch.manual_seed(0)  # reproducible inputs across runs
    _input = torch.randn(B, T + 1, V, device=device, dtype=dtype)

    logits1 = _input.clone().requires_grad_(True)
    logits2 = _input.clone().requires_grad_(True)
    logits3 = _input.clone().float().requires_grad_(True)

    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
    completion_mask[:, -20:] = 0

    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32) if num_iteration > 1 else None
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    # Create vLLM IS ratio (random values in [0.001, 1.0) to simulate typical IS correction)
    vllm_is_ratio = torch.rand(B, T, device=device, dtype=torch.float32) * 0.999 + 0.001

    # Positional arguments following ``logits`` that both the torch reference and
    # triton_grpo_loss accept in this order.
    shared_args = (
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low,
        eps_high,
    )

    loss1, kl1, _ = torch_grpo_loss_with_vllm_is(logits1, *shared_args, vllm_is_ratio, loss_type=loss_type)
    loss2, kl2, _ = triton_grpo_loss(
        logits2,
        *shared_args,
        inplace=True,
        vllm_is_ratio=vllm_is_ratio,
        loss_type=loss_type,
    )
    loss3, kl3, _ = torch_grpo_loss_with_vllm_is(logits3, *shared_args, vllm_is_ratio, loss_type=loss_type)

    dy = torch.randn_like(loss3)
    loss1.backward(dy)
    loss2.backward(dy)
    loss3.backward(dy)

    # Compare the low-precision torch reference and triton against the fp32 torch reference.
    # Previously loss1/kl1/logits1.grad were computed but never checked.
    for loss, kl, logits in ((loss1, kl1, logits1), (loss2, kl2, logits2)):
        assert_verbose_allclose(loss, loss3, atol=atol, rtol=rtol)
        if kl is not None and kl3 is not None:
            assert_verbose_allclose(kl, kl3, atol=atol, rtol=rtol)
        assert_verbose_allclose(logits.grad, logits3.grad, atol=atol, rtol=rtol)

    def _run_triton_fp32(ratio):
        """Run triton on fresh fp32 logits with ``ratio``, backprop ``dy``, and return (loss, logits)."""
        logits = _input.clone().float().requires_grad_(True)
        loss, _, _ = triton_grpo_loss(
            logits,
            *shared_args,
            inplace=False,
            vllm_is_ratio=ratio,
            loss_type=loss_type,
        )
        loss.backward(dy)
        return loss, logits

    # Verify vllm_is_ratio=None gives same result as vllm_is_ratio=ones (loss and gradient)
    loss_none, logits_none = _run_triton_fp32(None)
    loss_ones, logits_ones = _run_triton_fp32(torch.ones(B, T, device=device, dtype=torch.float32))
    assert_verbose_allclose(loss_none, loss_ones, atol=1e-5, rtol=1e-5)
    assert_verbose_allclose(logits_none.grad, logits_ones.grad, atol=1e-5, rtol=1e-5)

    # Verify broadcastable ratio shapes: (B, 1) vs uniform (B, T), then 1-D (B,) vs (B, 1).
    uniform_val = 0.42
    for shape_a, shape_b in (((B, 1), (B, T)), ((B,), (B, 1))):
        loss_a, logits_a = _run_triton_fp32(
            torch.full(shape_a, uniform_val, device=device, dtype=torch.float32)
        )
        loss_b, logits_b = _run_triton_fp32(
            torch.full(shape_b, uniform_val, device=device, dtype=torch.float32)
        )
        assert_verbose_allclose(loss_a, loss_b, atol=1e-5, rtol=1e-5)
        assert_verbose_allclose(logits_a.grad, logits_b.grad, atol=1e-5, rtol=1e-5)


@pytest.mark.parametrize("beta", [0.0, 0.04])
def test_grpo_loss_sequence_backward_matches_reference(beta):
    """Check that sequence-level importance sampling in triton_grpo_loss reproduces
    the TRL reference: same reduced loss and same gradient w.r.t. the logits."""
    pytest.importorskip("triton")
    torch.manual_seed(0)

    B, T, V = 2, 8, 32
    temperature = 1.1
    eps_low, eps_high = 0.2, 0.4

    logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32)
    completion_ids = torch.randint(0, V, (B, T), device=device)
    completion_mask = torch.ones(B, T, device=device, dtype=torch.float32)
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    # Old/reference log-probs are noisy copies of the current policy's token log-probs.
    with torch.no_grad():
        scaled_logits = logits[:, :-1, :] / temperature
        token_logp = torch.log_softmax(scaled_logits, dim=-1).gather(-1, completion_ids.unsqueeze(-1)).squeeze(-1)
        old_logp = token_logp + torch.randn_like(token_logp) * 0.2
        if beta != 0.0:
            ref_logp = token_logp + torch.randn_like(token_logp) * 0.1
        else:
            ref_logp = None

    # Inputs passed positionally (after the logits) to both implementations.
    shared_inputs = (old_logp, ref_logp, completion_ids, advantages, completion_mask)

    triton_logits = logits.clone().requires_grad_(True)
    triton_loss, _ = triton_grpo_loss(
        triton_logits,
        *shared_inputs,
        temperature=temperature,
        beta=beta,
        eps_low=eps_low,
        eps_high=eps_high,
        importance_sampling_level="sequence",
        loss_type="grpo",
        max_completion_length=T,
        reduce=True,
    )
    triton_loss.backward()

    expected_logits = logits.clone().requires_grad_(True)
    expected_loss = trl_reference_grpo_loss(
        expected_logits,
        *shared_inputs,
        temperature,
        beta,
        eps_low,
        eps_high,
        loss_type="grpo",
        importance_sampling_level="sequence",
    )
    expected_loss.backward()

    torch.testing.assert_close(triton_loss, expected_loss, rtol=1e-5, atol=1e-5)
    torch.testing.assert_close(triton_logits.grad, expected_logits.grad, rtol=1e-4, atol=1e-4)


@pytest.mark.parametrize(
    "temperature, num_iteration, beta, eps_high",
    [(0.7, num_iteration, beta, 5.0) for num_iteration in [1, 5] for beta in [0.0, 0.04]],
)
@pytest.mark.parametrize(
    "B, T, V",
    [
        (1, 1024, 151936),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-2, 5e-1),
    ],
)
def test_cispo_loss(B, T, V, temperature, num_iteration, beta, eps_high, dtype, atol, rtol):
    """Exercise the CISPO loss path of the triton kernel.

    Both the torch reference on ``dtype`` logits and the triton kernel on ``dtype``
    logits are compared (loss, KL when present, logits gradient) against the torch
    reference evaluated on fp32 logits.
    """
    base = torch.randn(B, T + 1, V, device=device, dtype=dtype)

    logits_ref_lowp = base.clone().requires_grad_(True)
    logits_triton = base.clone().requires_grad_(True)
    logits_ref_fp32 = base.clone().float().requires_grad_(True)

    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
    completion_mask[:, -100:] = 0

    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32) if num_iteration > 1 else None
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    # Arguments after the logits, as torch_cispo_loss takes them.
    reference_args = (old_logp, ref_logp, completion_ids, advantages, completion_mask, temperature, beta, eps_high)

    loss_ref_lowp, kl_ref_lowp, _ = torch_cispo_loss(logits_ref_lowp, *reference_args)

    loss_triton, kl_triton, _ = triton_grpo_loss(
        logits_triton,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low=0.2,  # not used for CISPO
        eps_high=eps_high,
        inplace=True,
        loss_type="cispo",
    )

    loss_ref_fp32, kl_ref_fp32, _ = torch_cispo_loss(logits_ref_fp32, *reference_args)

    upstream = torch.randn_like(loss_ref_fp32)
    for loss in (loss_ref_lowp, loss_triton, loss_ref_fp32):
        loss.backward(upstream)

    candidates = (
        (loss_ref_lowp, kl_ref_lowp, logits_ref_lowp),
        (loss_triton, kl_triton, logits_triton),
    )
    for loss, kl, logits in candidates:
        assert_verbose_allclose(loss, loss_ref_fp32, atol=atol, rtol=rtol)
        if kl is not None and kl_ref_fp32 is not None:
            assert_verbose_allclose(kl, kl_ref_fp32, atol=atol, rtol=rtol)
        assert_verbose_allclose(logits.grad, logits_ref_fp32.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize(
    "temperature, num_iteration, beta, sapo_temp_pos, sapo_temp_neg",
    [(0.7, num_iteration, beta, 1.0, 1.05) for num_iteration in [1, 5] for beta in [0.0, 0.04]],
)
@pytest.mark.parametrize(
    "B, T, V",
    [
        (1, 1024, 151936),
    ],
)
@pytest.mark.parametrize(
    "dtype, atol, rtol",
    [
        (torch.bfloat16, 5e-2, 5e-1),
    ],
)
def test_sapo_loss(B, T, V, temperature, num_iteration, beta, sapo_temp_pos, sapo_temp_neg, dtype, atol, rtol):
    """Exercise the SAPO loss path of the triton kernel.

    Both the torch reference on ``dtype`` logits and the triton kernel on ``dtype``
    logits are compared (loss, KL when present, logits gradient) against the torch
    reference evaluated on fp32 logits.
    """
    base = torch.randn(B, T + 1, V, device=device, dtype=dtype)

    logits_ref_lowp = base.clone().requires_grad_(True)
    logits_triton = base.clone().requires_grad_(True)
    logits_ref_fp32 = base.clone().float().requires_grad_(True)

    completion_ids = torch.randint(0, V - 1, (B, T), dtype=torch.int64, device=device)
    completion_mask = torch.ones_like(completion_ids, dtype=torch.int32)
    completion_mask[:, -100:] = 0

    ref_logp = torch.randn(B, T, device=device, dtype=torch.float32) if beta != 0.0 else None
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32) if num_iteration > 1 else None
    advantages = torch.randn(B, device=device, dtype=torch.float32)

    # Arguments after the logits, as torch_sapo_loss takes them.
    reference_args = (
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        sapo_temp_pos,
        sapo_temp_neg,
    )

    loss_ref_lowp, kl_ref_lowp, _ = torch_sapo_loss(logits_ref_lowp, *reference_args)

    loss_triton, kl_triton, _ = triton_grpo_loss(
        logits_triton,
        old_logp,
        ref_logp,
        completion_ids,
        advantages,
        completion_mask,
        temperature,
        beta,
        eps_low=0.2,  # not used for SAPO
        eps_high=0.4,  # not used for SAPO
        inplace=True,
        loss_type="sapo",
        sapo_temperature_pos=sapo_temp_pos,
        sapo_temperature_neg=sapo_temp_neg,
    )

    loss_ref_fp32, kl_ref_fp32, _ = torch_sapo_loss(logits_ref_fp32, *reference_args)

    upstream = torch.randn_like(loss_ref_fp32)
    for loss in (loss_ref_lowp, loss_triton, loss_ref_fp32):
        loss.backward(upstream)

    candidates = (
        (loss_ref_lowp, kl_ref_lowp, logits_ref_lowp),
        (loss_triton, kl_triton, logits_triton),
    )
    for loss, kl, logits in candidates:
        assert_verbose_allclose(loss, loss_ref_fp32, atol=atol, rtol=rtol)
        if kl is not None and kl_ref_fp32 is not None:
            assert_verbose_allclose(kl, kl_ref_fp32, atol=atol, rtol=rtol)
        assert_verbose_allclose(logits.grad, logits_ref_fp32.grad, atol=atol, rtol=rtol)


@pytest.mark.parametrize("loss_type", ["cispo", "sapo"])
def test_triton_sequence_level_rejects_unsupported_loss_types(loss_type):
    """triton_grpo_loss must raise ValueError when sequence-level importance
    sampling is requested together with the cispo or sapo loss types."""
    B, T, V = 2, 8, 32
    logits = torch.randn(B, T + 1, V, device=device, dtype=torch.float32).contiguous()
    completion_ids = torch.randint(0, V, (B, T), device=device)
    completion_mask = torch.ones(B, T, device=device, dtype=torch.float32)
    advantages = torch.randn(B, device=device, dtype=torch.float32)
    old_logp = torch.randn(B, T, device=device, dtype=torch.float32)

    # ref_logp is omitted (None) because beta is 0.0.
    options = dict(
        temperature=0.9,
        beta=0.0,
        eps_low=0.2,
        eps_high=0.4,
        importance_sampling_level="sequence",
        loss_type=loss_type,
        reduce=True,
    )

    with pytest.raises(ValueError, match="Sequence-level importance sampling is not supported"):
        triton_grpo_loss(logits, old_logp, None, completion_ids, advantages, completion_mask, **options)