nemo_2.3.0_te.patch 103 KB
Newer Older
jerrrrry's avatar
jerrrrry committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
1001
1002
1003
1004
1005
1006
1007
1008
1009
1010
1011
1012
1013
1014
1015
1016
1017
1018
1019
1020
1021
1022
1023
1024
1025
1026
1027
1028
1029
1030
1031
1032
1033
1034
1035
1036
1037
1038
1039
1040
1041
1042
1043
1044
1045
1046
1047
1048
1049
1050
1051
1052
1053
1054
1055
1056
1057
1058
1059
1060
1061
1062
1063
1064
1065
1066
1067
1068
1069
1070
1071
1072
1073
1074
1075
1076
1077
1078
1079
1080
1081
1082
1083
1084
1085
1086
1087
1088
1089
1090
1091
1092
1093
1094
1095
1096
1097
1098
1099
1100
1101
1102
1103
1104
1105
1106
1107
1108
1109
1110
1111
1112
1113
1114
1115
1116
1117
1118
1119
1120
1121
1122
1123
1124
1125
1126
1127
1128
1129
1130
1131
1132
1133
1134
1135
1136
1137
1138
1139
1140
1141
1142
1143
1144
1145
1146
1147
1148
1149
1150
1151
1152
1153
1154
1155
1156
1157
1158
1159
1160
1161
1162
1163
1164
1165
1166
1167
1168
1169
1170
1171
1172
1173
1174
1175
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
1186
1187
1188
1189
1190
1191
1192
1193
1194
1195
1196
1197
1198
1199
1200
1201
1202
1203
1204
1205
1206
1207
1208
1209
1210
1211
1212
1213
1214
1215
1216
1217
1218
1219
1220
1221
1222
1223
1224
1225
1226
1227
1228
1229
1230
1231
1232
1233
1234
1235
1236
1237
1238
1239
1240
1241
1242
1243
1244
1245
1246
1247
1248
1249
1250
1251
1252
1253
1254
1255
1256
1257
1258
1259
1260
1261
1262
1263
1264
1265
1266
1267
1268
1269
1270
1271
1272
1273
1274
1275
1276
1277
1278
1279
1280
1281
1282
1283
1284
1285
1286
1287
1288
1289
1290
1291
1292
1293
1294
1295
1296
1297
1298
1299
1300
1301
1302
1303
1304
1305
1306
1307
1308
1309
1310
1311
1312
1313
1314
1315
1316
1317
1318
1319
1320
1321
1322
1323
1324
1325
1326
1327
1328
1329
1330
1331
1332
1333
1334
1335
1336
1337
1338
1339
1340
1341
1342
1343
1344
1345
1346
1347
1348
1349
1350
1351
1352
1353
1354
1355
1356
1357
1358
1359
1360
1361
1362
1363
1364
1365
1366
1367
1368
1369
1370
1371
1372
1373
1374
1375
1376
1377
1378
1379
1380
1381
1382
1383
1384
1385
1386
1387
1388
1389
1390
1391
1392
1393
1394
1395
1396
1397
1398
1399
1400
1401
1402
1403
1404
1405
1406
1407
1408
1409
1410
1411
1412
1413
1414
1415
1416
1417
1418
1419
1420
1421
1422
1423
1424
1425
1426
1427
1428
1429
1430
1431
1432
1433
1434
1435
1436
1437
1438
1439
1440
1441
1442
1443
1444
1445
1446
1447
1448
1449
1450
1451
1452
1453
1454
1455
1456
1457
1458
1459
1460
1461
1462
1463
1464
1465
1466
1467
1468
1469
1470
1471
1472
1473
1474
1475
1476
1477
1478
1479
1480
1481
1482
1483
1484
1485
1486
1487
1488
1489
1490
1491
1492
1493
1494
1495
1496
1497
1498
1499
1500
1501
1502
1503
1504
1505
1506
1507
1508
1509
1510
1511
1512
1513
1514
1515
1516
1517
1518
1519
1520
1521
1522
1523
1524
1525
1526
1527
1528
1529
1530
1531
1532
1533
1534
1535
1536
1537
1538
1539
1540
1541
1542
1543
1544
1545
1546
1547
1548
1549
1550
1551
1552
1553
1554
1555
1556
1557
1558
1559
1560
1561
1562
1563
1564
1565
1566
1567
1568
1569
1570
1571
1572
1573
1574
1575
1576
1577
1578
1579
1580
1581
1582
1583
1584
1585
1586
1587
1588
1589
1590
1591
1592
1593
1594
1595
1596
1597
1598
1599
1600
1601
1602
1603
1604
1605
1606
1607
1608
1609
1610
1611
1612
1613
1614
1615
1616
1617
1618
1619
1620
1621
1622
1623
1624
1625
1626
1627
1628
1629
1630
1631
1632
1633
1634
1635
1636
1637
1638
1639
1640
1641
1642
1643
1644
1645
1646
1647
1648
1649
1650
1651
1652
1653
1654
1655
1656
1657
1658
1659
1660
1661
1662
1663
1664
1665
1666
1667
1668
1669
1670
1671
1672
1673
1674
1675
1676
1677
1678
1679
1680
1681
1682
1683
1684
1685
1686
1687
1688
1689
1690
1691
1692
1693
1694
1695
1696
1697
1698
1699
1700
1701
1702
1703
1704
1705
1706
1707
1708
1709
1710
1711
1712
1713
1714
1715
1716
1717
1718
1719
1720
1721
1722
1723
1724
1725
1726
1727
1728
1729
1730
1731
1732
1733
1734
1735
1736
1737
1738
1739
1740
1741
1742
1743
1744
1745
1746
1747
1748
1749
1750
1751
1752
1753
1754
1755
1756
1757
1758
1759
1760
1761
1762
1763
1764
1765
1766
1767
1768
1769
1770
1771
1772
1773
1774
1775
1776
1777
1778
1779
1780
1781
1782
1783
1784
1785
1786
1787
1788
1789
1790
1791
1792
1793
1794
1795
1796
1797
1798
1799
1800
1801
1802
1803
1804
1805
1806
1807
1808
1809
1810
1811
1812
1813
1814
1815
1816
1817
1818
1819
1820
1821
1822
1823
1824
1825
1826
1827
1828
1829
1830
1831
1832
1833
1834
1835
1836
1837
1838
1839
1840
1841
1842
1843
1844
1845
1846
1847
1848
1849
1850
1851
1852
1853
1854
1855
1856
1857
1858
1859
1860
1861
1862
1863
1864
1865
1866
1867
1868
1869
1870
1871
1872
1873
1874
1875
1876
1877
1878
1879
1880
1881
1882
1883
1884
1885
1886
1887
1888
1889
1890
1891
1892
1893
1894
1895
1896
1897
1898
1899
1900
1901
1902
1903
1904
1905
1906
1907
1908
1909
1910
1911
1912
1913
1914
1915
1916
1917
1918
1919
1920
1921
1922
1923
1924
1925
1926
1927
1928
1929
1930
1931
1932
1933
1934
1935
1936
1937
1938
1939
1940
1941
1942
1943
1944
1945
1946
1947
1948
1949
1950
1951
1952
1953
1954
1955
1956
1957
1958
1959
1960
1961
1962
1963
1964
1965
1966
1967
1968
1969
1970
1971
1972
1973
1974
1975
1976
1977
1978
1979
1980
1981
1982
1983
1984
1985
1986
1987
1988
1989
1990
1991
1992
1993
1994
1995
1996
1997
1998
1999
2000
2001
2002
2003
2004
2005
2006
2007
2008
2009
2010
2011
2012
2013
2014
2015
2016
2017
2018
2019
2020
2021
2022
2023
2024
2025
2026
2027
2028
2029
2030
2031
2032
2033
2034
2035
2036
2037
2038
2039
2040
2041
2042
2043
2044
2045
2046
2047
2048
2049
2050
2051
2052
2053
2054
2055
2056
2057
2058
2059
2060
2061
2062
2063
2064
2065
2066
2067
2068
2069
2070
2071
2072
2073
2074
2075
2076
2077
2078
2079
2080
2081
2082
2083
2084
2085
2086
2087
2088
2089
2090
2091
2092
2093
2094
2095
2096
2097
2098
2099
2100
2101
2102
2103
2104
2105
2106
2107
2108
2109
2110
2111
2112
2113
2114
2115
2116
2117
2118
2119
2120
2121
2122
2123
2124
2125
2126
2127
2128
2129
2130
2131
2132
2133
2134
2135
2136
2137
2138
2139
2140
2141
2142
2143
2144
2145
2146
2147
2148
2149
2150
2151
2152
2153
2154
2155
2156
2157
2158
2159
2160
2161
2162
2163
2164
2165
2166
2167
2168
2169
2170
2171
2172
2173
2174
2175
2176
2177
2178
2179
2180
2181
2182
2183
2184
2185
2186
2187
2188
2189
2190
2191
2192
2193
2194
2195
2196
2197
2198
2199
2200
2201
2202
2203
2204
2205
2206
2207
2208
2209
2210
2211
2212
2213
2214
2215
2216
2217
2218
2219
2220
2221
2222
2223
2224
2225
2226
2227
2228
2229
2230
2231
2232
2233
2234
2235
2236
2237
2238
2239
2240
2241
2242
2243
2244
2245
2246
2247
2248
2249
2250
2251
2252
2253
2254
2255
2256
2257
2258
2259
2260
2261
2262
2263
2264
2265
2266
2267
2268
2269
2270
2271
2272
2273
2274
2275
2276
2277
2278
2279
2280
2281
2282
2283
2284
2285
2286
2287
2288
2289
2290
2291
2292
2293
2294
2295
2296
2297
2298
2299
2300
2301
2302
2303
2304
2305
2306
2307
2308
2309
2310
2311
2312
2313
2314
2315
2316
2317
2318
2319
2320
2321
2322
2323
2324
2325
2326
2327
2328
2329
2330
2331
2332
2333
2334
2335
2336
2337
2338
2339
2340
2341
2342
2343
2344
2345
2346
2347
2348
2349
diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh
index 732f0a1..29f40bb 100644
--- a/qa/L0_pytorch_unittest/test.sh
+++ b/qa/L0_pytorch_unittest/test.sh
@@ -39,6 +39,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py ||
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
 NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python3 -m pytest -o log_cli=true --log-cli-level=INFO -v -s $TE_PATH/tests/pytorch/fused_attn/test_paged_attn.py || test_fail "test_paged_attn.py"
+
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
     exit 1
diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh
index 5776734..36d491e 100644
--- a/qa/L1_pytorch_distributed_unittest/test.sh
+++ b/qa/L1_pytorch_distributed_unittest/test.sh
@@ -26,6 +26,7 @@ python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py |
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
 # python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" ### TODO Debug UB support with te.Sequential
 python3 -m pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
+python3 -m pytest -v -s $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"
 
 if [ "$RET" -ne 0 ]; then
     echo "Error in the following test cases:$FAILED_CASES"
diff --git a/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py
new file mode 100644
index 0000000..1b38f72
--- /dev/null
+++ b/tests/pytorch/distributed/run_cast_master_weights_to_fp8.py
@@ -0,0 +1,671 @@
+#!/usr/bin/python3
+
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import argparse
+import datetime
+import os
+import sys
+
+import torch
+from torch import nn
+import torch.distributed as dist
+
+from transformer_engine.common.recipe import (
+    DelayedScaling,
+    Float8CurrentScaling,
+    Format,
+    Recipe,
+)
+import transformer_engine.pytorch as te
+from transformer_engine.pytorch.tensor import QuantizedTensor, cast_master_weights_to_fp8
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Tensor,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.utils import replace_raw_data
+
+
+def _get_raw_data(quantized_tensor):
+    """Get the underlying data of a quantized tensor, used in zero-1 optimizer"""
+    if isinstance(quantized_tensor, Float8Tensor):
+        assert hasattr(quantized_tensor, "_data"), "Float8Tensor does not have _data attribute"
+        assert quantized_tensor._data.dtype == torch.uint8, "Float8Tensor _data must be uint8"
+        return quantized_tensor._data
+    else:
+        raise ValueError(f"Unsupported quantized tensor type: {type(quantized_tensor)}")
+
+
+class MiniZero_1:
+    """A mini zero-1 optimizer implementation, just used for this test"""
+
+    def __init__(self, weights, lr, dp_group):
+        self.rank = dist.get_rank(dp_group)
+        self.world_size = dist.get_world_size(dp_group)
+
+        self.weights = weights
+        self.lr = lr
+        self.dp_group = dp_group
+
+        # [self.offsets[i], self.offsets[i+1]) is the range of weights[i] in the global buffer
+        self.offsets = [0]
+        for weight in self.weights:
+            self.offsets.append(self.offsets[-1] + weight.numel())
+
+        # Padding to avoid global buffer cannot be divided by world size, so the offsets[-1] may
+        # not be the end range of the last weight.
+        if self.offsets[-1] % self.world_size != 0:
+            self.offsets[-1] += self.world_size - self.offsets[-1] % self.world_size
+
+        self.master_weights = []
+        # The start offset of the master weight in the weight
+        self.start_offsets = []
+        # The overlapping area of the weight and this rank's local buffer
+        self.overlapping_areas = []
+
+        # The start and end of this rank's local buffer in the global buffer
+        rank_start = self.offsets[-1] // self.world_size * self.rank
+        rank_end = rank_start + self.offsets[-1] // self.world_size
+
+        for weight, offset in zip(self.weights, self.offsets[:-1]):
+            if offset >= rank_end or (offset + weight.numel()) <= rank_start:
+                # This weight is not in this rank's local buffer
+                master_weight = None
+                start_offset = None
+                overlapping_area = None
+            else:
+                overlapping_start = max(rank_start, offset)
+                overlapping_end = min(rank_end, offset + weight.numel())
+                length = overlapping_end - overlapping_start
+                start_offset = overlapping_start - offset
+                if isinstance(weight, QuantizedTensor):
+                    # If weight is a FP8 tensor, we need to use the original high precision version
+                    # to initialize the master weight.
+                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
+                    master_weight = high_precision_init_val.to(weight.device).float()[
+                        start_offset : start_offset + length
+                    ]
+                else:
+                    master_weight = (
+                        weight.detach().view(-1).float()[start_offset : start_offset + length]
+                    )
+                overlapping_area = (overlapping_start, overlapping_end)
+            self.master_weights.append(master_weight)
+            self.start_offsets.append(start_offset)
+            self.overlapping_areas.append(overlapping_area)
+
+        # Create global buffer for grads reduce-scatter
+        self.grad_buffer = torch.empty(
+            [self.offsets[-1]], dtype=torch.float32, device=weights[0].device
+        )
+        self.grad_buffer_slice = self.grad_buffer[rank_start:rank_end]
+
+        # Create global buffer for weights all-gather
+        if isinstance(self.weights[0], QuantizedTensor):
+            weight_buffer_dtype = torch.uint8
+        else:
+            weight_buffer_dtype = weights[0].dtype
+        self.weight_buffer = torch.empty(
+            [self.offsets[-1]], dtype=weight_buffer_dtype, device=weights[0].device
+        )
+        self.weight_buffer_slice = self.weight_buffer[rank_start:rank_end]
+
+    def step(self):
+        # -----------------------------------------------------------------------------------------
+        # Step 1: Copy grads to the grad buffer
+        # -----------------------------------------------------------------------------------------
+        for weight, offset in zip(self.weights, self.offsets[:-1]):
+            start = offset
+            end = offset + weight.numel()
+            self.grad_buffer[start:end].copy_(weight.main_grad.view(-1))
+
+        # -----------------------------------------------------------------------------------------
+        # Step 2: Grads reduce-scatter
+        # -----------------------------------------------------------------------------------------
+        # Don't use reduce_scatter directly to explicitly control the reduce order.
+        # dist.reduce_scatter_tensor(self.grad_buffer_slice, self.grad_buffer, op=dist.ReduceOp.AVG,
+        #                            group=self.dp_group)
+        buffers = [torch.empty_like(self.grad_buffer) for _ in range(self.world_size)]
+        dist.all_gather(buffers, self.grad_buffer, group=self.dp_group)
+        for i in range(1, self.world_size):
+            buffers[0] += buffers[i]
+        rank_start = self.offsets[-1] // self.world_size * self.rank
+        rank_end = rank_start + self.offsets[-1] // self.world_size
+        self.grad_buffer_slice.copy_(buffers[0][rank_start:rank_end])
+        self.grad_buffer_slice /= self.world_size
+
+        # -----------------------------------------------------------------------------------------
+        # Step 3: Update master weights
+        # -----------------------------------------------------------------------------------------
+        for master_weight, overlapping_area in zip(self.master_weights, self.overlapping_areas):
+            if master_weight is None:
+                # This weight's master weight is in other rank.
+                continue
+            grad = self.grad_buffer[overlapping_area[0] : overlapping_area[1]]
+            master_weight -= grad * self.lr
+
+        # -----------------------------------------------------------------------------------------
+        # Step 4: Cast master weights to BF16 or FP8, depending on the type of the weight
+        # -----------------------------------------------------------------------------------------
+        if isinstance(self.weights[0], QuantizedTensor):
+            # FP8 weights case
+            for i in range(1, len(self.weights)):
+                assert isinstance(self.weights[i], QuantizedTensor)
+            cast_master_weights_to_fp8(
+                self.weights, self.master_weights, self.start_offsets, self.dp_group
+            )
+        else:
+            # BF16 weights case
+            for weight, master_weight, start_offset in zip(
+                self.weights, self.master_weights, self.start_offsets
+            ):
+                if master_weight is None:
+                    continue
+                start = start_offset
+                end = start_offset + master_weight.numel()
+                weight.data.view(-1)[start:end].copy_(master_weight)
+
+        # -----------------------------------------------------------------------------------------
+        # Step 5: Copy the updated weights (not all weights) to the weight buffer
+        # -----------------------------------------------------------------------------------------
+        for i in range(len(self.weights)):
+            master_weight = self.master_weights[i]
+            if master_weight is None:
+                continue
+            start_offset = self.start_offsets[i]
+            if isinstance(self.weights[i], QuantizedTensor):
+                weight = _get_raw_data(self.weights[i])
+            else:
+                weight = self.weights[i]
+            weight_slice = weight.view(-1)[start_offset : start_offset + master_weight.numel()]
+            overlapping_start, overlapping_end = self.overlapping_areas[i]
+            self.weight_buffer[overlapping_start:overlapping_end].copy_(weight_slice)
+
+        # -----------------------------------------------------------------------------------------
+        # Step 6: Weight all-gather (FP8 or BF16)
+        # -----------------------------------------------------------------------------------------
+        dist.all_gather_into_tensor(
+            self.weight_buffer, self.weight_buffer_slice, group=self.dp_group
+        )
+
+        # -----------------------------------------------------------------------------------------
+        # Step 7: Copy the gathered weights from weight buffer to the actual weights
+        # -----------------------------------------------------------------------------------------
+        for weight, offset in zip(self.weights, self.offsets[:-1]):
+            start = offset
+            end = offset + weight.numel()
+            if isinstance(weight, QuantizedTensor):
+                weight = _get_raw_data(weight)
+            weight.view(-1).data.copy_(self.weight_buffer[start:end])
+
+
+class MiniOptimizer:
+
+    def __init__(self, weights, lr, dp_group):
+        self.world_size = dist.get_world_size(dp_group)
+
+        self.weights = weights
+        self.lr = lr
+        self.dp_group = dp_group
+
+        master_weights = []
+        for weight in self.weights:
+            master_weights.append(weight.detach().float())
+        self.master_weights = master_weights
+
+    def step(self):
+        for weight, master_weight in zip(self.weights, self.master_weights):
+            main_grad = weight.main_grad
+
+            # Don't use all-reduce directly to explicitly control the reduce order.
+            # dist.all_reduce(main_grad, op=dist.ReduceOp.AVG, group=self.dp_group)
+            buffers = [torch.empty_like(main_grad) for _ in range(self.world_size)]
+            dist.all_gather(buffers, main_grad, group=self.dp_group)
+            for i in range(1, self.world_size):
+                buffers[0] += buffers[i]
+            main_grad.copy_(buffers[0])
+            main_grad /= self.world_size
+
+            master_weight -= main_grad * self.lr
+            weight.data.copy_(master_weight)
+
+
+class MiniFSDP:
+    def __init__(self, weights, lr, dp_group):
+        rank = dist.get_rank(dp_group)
+        world_size = dist.get_world_size(dp_group)
+
+        self.weights = weights
+        self.lr = lr
+        self.dp_group = dp_group
+
+        # Flatten the weights and pad to align with world size
+        raw_data_list = [
+            _get_raw_data(w).view(-1) if isinstance(w, QuantizedTensor) else w.view(-1)
+            for w in weights
+        ]
+        if isinstance(weights[0], QuantizedTensor):
+            raw_data_list = [_get_raw_data(w).view(-1) for w in weights]
+        else:
+            raw_data_list = [w.view(-1) for w in weights]
+        self.flatten_weight, original_length = self._flatten_tensors_with_pad(raw_data_list)
+
+        # Split flattened weights into shards
+        self.local_weight_shard = torch.chunk(self.flatten_weight, world_size)[rank]
+        self.local_main_grad_shard = torch.zeros_like(self.local_weight_shard)
+        shard_size = self.flatten_weight.size(0) // world_size
+
+        # Map original tensors to flattened indices
+        tensor_indices = []
+        cumulative_length = 0
+        for tensor in raw_data_list:
+            length = tensor.size(0)
+            tensor_indices.append((cumulative_length, cumulative_length + length))
+            cumulative_length += length
+
+        # Build shard index mappings
+        self.weight_indices = []
+        self.shard_indices = []
+        for idx, (start, end) in enumerate(tensor_indices):
+            shard_start = rank * shard_size
+            shard_end = shard_start + shard_size
+            adjusted_end = min(shard_end, original_length)
+
+            if start <= adjusted_end and end >= shard_start:
+                start_idx = max(start, shard_start)
+                end_idx = min(end, adjusted_end)
+                self.weight_indices.append((start_idx - start, end_idx - start))
+                self.shard_indices.append((start_idx - shard_start, end_idx - shard_start))
+            else:
+                self.weight_indices.append((None, None))
+                self.shard_indices.append((None, None))
+
+            if isinstance(weights[idx], QuantizedTensor):
+                replace_raw_data(
+                    weights[idx], self.flatten_weight[start:end].view(weights[idx].shape)
+                )
+            else:
+                weights[idx].data = self.flatten_weight[start:end].view(weights[idx].shape)
+
+        # Initialize local model weights and high-precision master weights
+        self.local_weights = []
+        self.master_weights = []
+        for i, weight in enumerate(self.weights):
+            weight_start, weight_end = self.weight_indices[i]
+            shard_start, shard_end = self.shard_indices[i]
+            if shard_start is not None and shard_end is not None:
+                local_weight_shard = self.local_weight_shard[shard_start:shard_end]
+                self.local_weights.append(local_weight_shard)
+
+                if isinstance(weight, QuantizedTensor):
+                    high_precision_init_val = weight.get_high_precision_init_val().view(-1)
+                    master_weight_shard = high_precision_init_val.to(weight.device).float()[
+                        weight_start:weight_end
+                    ]
+                else:
+                    master_weight_shard = weight.detach().view(-1).float()[weight_start:weight_end]
+                self.master_weights.append(master_weight_shard)
+            else:
+                self.local_weights.append(None)
+                self.master_weights.append(None)
+            setattr(
+                weight, "main_grad", torch.zeros_like(weight, dtype=torch.float32, device="cuda")
+            )
+
+    def _flatten_tensors_with_pad(self, tensors):
+        """
+        Flatten the list of tensors and pad them to align with the world size.
+
+        Args:
+            tensors (list): List of tensors to flatten.
+
+        Returns:
+            tuple: Flattened tensor and its original length before padding.
+        """
+        world_size = dist.get_world_size(self.dp_group)
+
+        flatten_tensor = torch.cat(tensors)
+        original_length = flatten_tensor.size(0)
+
+        padding_needed = (world_size - original_length % world_size) % world_size
+        if padding_needed > 0:
+            flatten_tensor = torch.cat(
+                [flatten_tensor, torch.zeros(padding_needed, dtype=flatten_tensor.dtype)]
+            )
+
+        return flatten_tensor, original_length
+
+    def zero_grad(self):
+        for weight in self.weights:
+            weight.grad = None
+            weight.main_grad.zero_()
+
+    def step(self):
+        """
+        Perform an optimization step for the distributed sharded model.
+
+        This method includes:
+        1. Gradient reduce-scatter: Synchronize gradients across all processes.
+        2. Master weight update: Update high-precision master weights using local gradients.
+        3. Precision casting: Cast updated master weights to FP8 or BF16 precision.
+        4. Weight synchronization: All-gather updated weights across all processes.
+
+        Returns:
+            None
+        """
+        # Step 1: Reduce-scatter the gradients
+        main_grad_buffer, _ = self._flatten_tensors_with_pad(
+            [weight.main_grad.view(-1) for weight in self.weights]
+        )
+        main_grad_buffer = main_grad_buffer.to(self.local_main_grad_shard.dtype)
+        dist.reduce_scatter_tensor(
+            self.local_main_grad_shard, main_grad_buffer, group=self.dp_group
+        )
+
+        # Step 2: Update the master weights
+        for weight, master_weight, (shard_start, shard_end) in zip(
+            self.weights, self.master_weights, self.shard_indices
+        ):
+            if master_weight is None:
+                continue
+
+            # Extract the local gradient shard for this weight
+            grad = self.local_main_grad_shard[shard_start:shard_end]
+
+            # Update the master weight using gradient descent
+            master_weight -= grad * self.lr
+
+        # Step 3: Cast master weights to FP8 or BF16 precision
+        if isinstance(self.weights[0], QuantizedTensor):
+            local_weights = []
+            for local_weight in self.local_weights:
+                if local_weight is None:
+                    local_weights.append(None)
+                    continue
+
+                local_weights.append(local_weight)
+
+            cast_master_weights_to_fp8(
+                self.weights,
+                self.master_weights,
+                [idx[0] for idx in self.weight_indices],
+                self.dp_group,
+                local_weights,
+            )
+        else:
+            for weight, master_weight in zip(self.local_weights, self.master_weights):
+                if master_weight is None:
+                    continue
+
+                # Copy updated master weights to local weights
+                weight.data.copy_(master_weight)
+
+        # Step 4: All-gather updated weights across processes
+        dist.all_gather_into_tensor(
+            self.flatten_weight, self.local_weight_shard, group=self.dp_group
+        )
+
+
+def _test_fsdp_cast_master_weights_to_fp8(quantization, dp_group):
+    rank = dist.get_rank(dp_group)
+    world_size = dist.get_world_size(dp_group)
+
+    # Configuration constants
+    NUM_STEPS = 100
+    SEED = 12345
+
+    torch.manual_seed(SEED)
+    torch.cuda.manual_seed(SEED)
+
+    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
+    mock_group = mock_groups[rank]
+
+    linear_kwargs = {
+        "params_dtype": torch.bfloat16,
+        "bias": False,
+        "fuse_wgrad_accumulation": False,
+    }
+
+    # Create model with FP8 weights
+    with te.fp8.fp8_model_init(
+        enabled=quantization is not None,
+        recipe=quantization_recipe(quantization),
+        preserve_high_precision_init_val=True,
+    ):
+        model_fp8 = nn.Sequential(
+            te.Linear(128, 256, **linear_kwargs),
+            te.Linear(256, 256 * 3, **linear_kwargs),
+            te.Linear(256 * 3, 128, **linear_kwargs),
+        )
+
+    # Create model with BF16 weights
+    model = nn.Sequential(
+        te.Linear(128, 256, **linear_kwargs),
+        te.Linear(256, 256 * 3, **linear_kwargs),
+        te.Linear(256 * 3, 128, **linear_kwargs),
+    )
+
+    # Make sure the BF16 model and FP8 model have the same initial weights
+    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
+        high_precision_init_val = w_fp8.get_high_precision_init_val()
+        w.data.copy_(high_precision_init_val)
+
+    optimizer_fp8 = MiniFSDP([w for w in model_fp8.parameters()], 10.0, dp_group)
+    optimizer = MiniFSDP([w for w in model.parameters()], 10.0, dp_group)
+
+    for _ in range(100):
+        optimizer_fp8.zero_grad()
+        optimizer.zero_grad()
+
+        inputs = [
+            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
+        ]
+        # Choose based on rank to make sure the inputs of different ranks are different.
+        x = inputs[rank]
+
+        with te.fp8.fp8_autocast(
+            enabled=quantization is not None,
+            fp8_recipe=quantization_recipe(quantization),
+            fp8_group=mock_group,
+        ):
+            y_fp8 = model_fp8(x)
+
+        with te.fp8_autocast(
+            enabled=quantization is not None,
+            fp8_recipe=quantization_recipe(quantization),
+            fp8_group=mock_group,
+        ):
+            y = model(x)
+
+        targets = [torch.randn_like(y) for _ in range(world_size)]
+        # Choose based on rank to make sure the targets of different ranks are different.
+        target = targets[rank]
+        loss_fp8 = nn.MSELoss()(y_fp8, target)
+        loss = nn.MSELoss()(y, target)
+
+        loss_fp8.backward()
+        loss.backward()
+
+        optimizer_fp8.step()
+        optimizer.step()
+
+        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
+
+    print(
+        f"✅ Successfully validated FSDP {NUM_STEPS} training steps with"
+        f" {quantization} quantization"
+    )
+
+
+def _test_zero_1(dp_group):
+    """Make sure the implementation of zero-1 optimizer is correct"""
+    rank = dist.get_rank(dp_group)
+    world_size = dist.get_world_size(dp_group)
+
+    torch.manual_seed(12345)
+    torch.cuda.manual_seed(12345)
+
+    weights = [
+        torch.randn(256 * 256, dtype=torch.bfloat16, device="cuda"),
+        torch.randn(256 * 256 * 3, dtype=torch.bfloat16, device="cuda"),
+        torch.randn(256 * 256 * 2 - 1, dtype=torch.bfloat16, device="cuda"),
+    ]
+
+    weights_1 = weights
+    weights_2 = [weight.clone() for weight in weights]
+
+    lr = 1.0
+    optimizer_1 = MiniZero_1(weights_1, lr, dp_group)
+    optimizer_2 = MiniOptimizer(weights_2, lr, dp_group)
+
+    for _ in range(100):
+        for w1, w2 in zip(weights_1, weights_2):
+            main_grads = [
+                torch.randn_like(w1, dtype=torch.float32, device="cuda") for _ in range(world_size)
+            ]
+            # Choose based on rank to make sure the grads of different ranks are different.
+            main_grad = main_grads[rank]
+            w1.main_grad = main_grad
+            w2.main_grad = main_grad
+
+        optimizer_1.step()
+        optimizer_2.step()
+
+        for w1, w2 in zip(weights_1, weights_2):
+            torch.testing.assert_close(w1, w2, atol=0, rtol=0)
+
+
+def quantization_recipe(quantization) -> Recipe:
+    """Quantization recipe setup"""
+    if quantization == "fp8":
+        return DelayedScaling(
+            fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
+        )
+    elif quantization == "fp8_cs":
+        return Float8CurrentScaling()
+    else:
+        raise ValueError(f"Unsupported quantization: {quantization}")
+
+
+def _test_cast_master_weights_to_fp8(quantization, dp_group):
+    rank = dist.get_rank(dp_group)
+    world_size = dist.get_world_size(dp_group)
+
+    torch.manual_seed(12345)
+    torch.cuda.manual_seed(12345)
+
+    mock_groups = [dist.new_group(ranks=[i]) for i in range(world_size)]
+    mock_group = mock_groups[rank]
+
+    linear_kwargs = {"params_dtype": torch.bfloat16, "bias": False, "fuse_wgrad_accumulation": True}
+
+    # Create model with FP8 weights
+    with te.fp8.fp8_model_init(
+        enabled=quantization is not None,
+        recipe=quantization_recipe(quantization),
+        preserve_high_precision_init_val=True,
+    ):
+        model_fp8 = nn.Sequential(
+            te.Linear(128, 256, **linear_kwargs),
+            te.Linear(256, 256 * 3, **linear_kwargs),
+            te.Linear(256 * 3, 128, **linear_kwargs),
+        )
+
+    # Create model with BF16 weights
+    model = nn.Sequential(
+        te.Linear(128, 256, **linear_kwargs),
+        te.Linear(256, 256 * 3, **linear_kwargs),
+        te.Linear(256 * 3, 128, **linear_kwargs),
+    )
+
+    # Make sure the BF16 model and FP8 model have the same initial weights
+    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
+        high_precision_init_val = w_fp8.get_high_precision_init_val()
+        w.data.copy_(high_precision_init_val)
+
+    # Allocate main_grads for each weight
+    for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
+        w_fp8.main_grad = torch.zeros_like(w_fp8, dtype=torch.float32, device="cuda")
+        w.main_grad = torch.zeros_like(w, dtype=torch.float32, device="cuda")
+
+    optimizer_fp8 = MiniZero_1([w for w in model_fp8.parameters()], 10.0, dp_group)
+    optimizer = MiniZero_1([w for w in model.parameters()], 10.0, dp_group)
+
+    for _ in range(100):
+        for w_fp8, w in zip(model_fp8.parameters(), model.parameters()):
+            w_fp8.main_grad.zero_()
+            w.main_grad.zero_()
+
+        inputs = [
+            torch.randn(16, 128, dtype=torch.bfloat16, device="cuda") for _ in range(world_size)
+        ]
+        # Choose based on rank to make sure the inputs of different ranks are different.
+        x = inputs[rank]
+
+        with te.fp8.fp8_autocast(
+            enabled=quantization is not None,
+            fp8_recipe=quantization_recipe(quantization),
+            fp8_group=mock_group,
+        ):
+            y_fp8 = model_fp8(x)
+
+        with te.fp8_autocast(
+            enabled=quantization is not None,
+            fp8_recipe=quantization_recipe(quantization),
+            fp8_group=mock_group,
+        ):
+            y = model(x)
+
+        targets = [torch.randn_like(y) for _ in range(world_size)]
+        # Choose based on rank to make sure the targets of different ranks are different.
+        target = targets[rank]
+        loss_fp8 = nn.MSELoss()(y_fp8, target)
+        loss = nn.MSELoss()(y, target)
+
+        loss_fp8.backward()
+        loss.backward()
+
+        optimizer_fp8.step()
+        optimizer.step()
+
+        torch.testing.assert_close(loss_fp8, loss, atol=0, rtol=0)
+
+
+def main(argv=None, namespace=None):
+    WORLD_RANK = int(os.getenv("RANK", "0"))
+    WORLD_SIZE = int(os.getenv("WORLD_SIZE", "1"))
+    LOCAL_RANK = int(os.getenv("LOCAL_RANK", "0"))
+    LOCAL_SIZE = int(os.getenv("LOCAL_WORLD_SIZE", "1"))
+
+    assert WORLD_SIZE == LOCAL_SIZE  # this test supports only 1 node
+    assert LOCAL_SIZE <= torch.cuda.device_count()
+    dist_init_kwargs = {
+        "backend": "nccl",
+        "rank": WORLD_RANK,
+        "world_size": WORLD_SIZE,
+        "timeout": datetime.timedelta(seconds=30),
+    }
+    dist_init_kwargs["init_method"] = "env://"
+    dist_init_kwargs["device_id"] = torch.device(f"cuda:{LOCAL_RANK}")
+    assert dist.is_nccl_available()
+    torch.cuda.set_device(LOCAL_RANK)
+    dist.init_process_group(**dist_init_kwargs)
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--quantization", type=str, default=None, choices=["fp8", "fp8_cs"])
+    args = parser.parse_args(argv, namespace)
+
+    dp_group = dist.new_group(backend="nccl")
+    _test_zero_1(dp_group)
+    _test_cast_master_weights_to_fp8(args.quantization, dp_group)
+    _test_fsdp_cast_master_weights_to_fp8(args.quantization, dp_group)
+
+    dist.destroy_process_group()
+    return 0
+
+
+if __name__ == "__main__":
+
+    sys.exit(main())
diff --git a/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
new file mode 100644
index 0000000..8ebe86b
--- /dev/null
+++ b/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py
@@ -0,0 +1,35 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+import os
+import subprocess
+from pathlib import Path
+
+import pytest
+import torch
+from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
+
+
+if torch.cuda.device_count() < 2:
+    pytest.skip("cast_master_weights_to_fp8 test needs at least 2 GPUs.")
+
+fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
+
+TEST_ROOT = Path(__file__).parent.resolve()
+NUM_PROCS: int = min(2, torch.cuda.device_count())
+LAUNCH_CMD = ["torchrun", f"--nproc_per_node={NUM_PROCS}"]
+
+
+def _run_test(quantization):
+    test_path = TEST_ROOT / "run_cast_master_weights_to_fp8.py"
+    test_cmd = LAUNCH_CMD + [str(test_path)] + ["--quantization", quantization]
+    result = subprocess.run(test_cmd, env=os.environ, check=False)
+    assert result.returncode == 0
+
+
+@pytest.mark.parametrize("quantization", ["fp8", "fp8_cs"])
+def test_cast_master_weights_to_fp8(quantization):
+    if not fp8_available:
+        pytest.skip(reason_for_no_fp8)
+    _run_test(quantization)
diff --git a/tests/pytorch/references/ref_per_tensor_cs.py b/tests/pytorch/references/ref_per_tensor_cs.py
index 1895b31..dad0c42 100644
--- a/tests/pytorch/references/ref_per_tensor_cs.py
+++ b/tests/pytorch/references/ref_per_tensor_cs.py
@@ -8,12 +8,8 @@ import transformer_engine_torch as tex
 from transformer_engine.pytorch.constants import TE_DType_To_Torch
 
 
-# compute amax and scale
-def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
-    x_fp32 = x.to(torch.float32)
-    amax = torch.amax(torch.abs(x_fp32)).view(1)
-    assert amax.dtype == torch.float, "amax must be a float tensor."
-    fp8_max = torch.finfo(quant_dtype).max
+# Compute scale and scale_inv from amax
+def _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
     # Clamping amax to avoid division by small numbers
     amax = torch.max(amax, torch.tensor(eps))
 
@@ -52,6 +48,20 @@ def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
     # Compute scale_inv
     scale_inv = torch.reciprocal(scale)
 
+    return scale, scale_inv
+
+
+# compute amax and scale
+def _ref_compute_amax_scale(x, quant_dtype, eps, pow_2_scales):
+    x_fp32 = x.to(torch.float32)
+    amax = torch.amax(torch.abs(x_fp32)).view(1)
+    assert amax.dtype == torch.float, "amax must be a float tensor."
+    fp8_max = torch.finfo(quant_dtype).max
+
+    scale, scale_inv = _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
+    # Clamping amax to avoid division by small numbers
+    amax = torch.max(amax, torch.tensor(eps))
+
     return scale, scale_inv, amax
 
 
@@ -103,3 +113,7 @@ def ref_per_tensor_cs_cast(
         qx_t = _multi_dim_transpose(qx)
         sx_t = sx
     return qx, sx, qx_t, sx_t
+
+
+def ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales):
+    return _ref_compute_scale_and_scale_inv_from_amax(amax, fp8_max, eps, pow_2_scales)
diff --git a/tests/pytorch/test_multi_tensor.py b/tests/pytorch/test_multi_tensor.py
index ecc06c3..4dc1ec0 100644
--- a/tests/pytorch/test_multi_tensor.py
+++ b/tests/pytorch/test_multi_tensor.py
@@ -9,6 +9,9 @@ import transformer_engine.pytorch as te
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.optimizers import MultiTensorApply
 
+from references.ref_per_tensor_cs import ref_compute_scale_and_scale_inv_from_amax
+
+
 input_size_pairs = [
     (7777 * 77, 555 * 555),
     (777, 555),
@@ -216,3 +219,42 @@ def test_multi_tensor_unscale_l2norm(input_size_pair, applier, repeat, in_type,
     if per_tensor:
         torch.testing.assert_close(norm_per_tensor, normab.broadcast_to(norm_per_tensor.shape))
     assert overflow_buf.item() == 0
+
+
+@pytest.mark.parametrize("input_size_pair", input_size_pairs + [(1, 1)])
+@pytest.mark.parametrize("applier", appliers)
+@pytest.mark.parametrize("repeat", [1, 55])
+@pytest.mark.parametrize("max_fp8", [448.0, 57344.0])
+@pytest.mark.parametrize("pow_2_scales", [False, True])
+@pytest.mark.parametrize("epsilon", [0.0, 100.0])
+def test_multi_tensor_compute_scale_and_scale_inv(
+    input_size_pair, applier, repeat, max_fp8, pow_2_scales, epsilon
+):
+    sizea, sizeb = input_size_pair
+    device = torch.device("cuda")
+    overflow_buf = torch.zeros(1, dtype=torch.int32, device=device)
+    a = torch.randn([sizea], dtype=torch.float32, device=device).abs()
+    b = torch.randn([sizeb], dtype=torch.float32, device=device).abs()
+
+    amax_list = []
+    for i in range(repeat):
+        amax_list += [a.clone(), b.clone()]
+
+    scale_list = [torch.empty_like(x) for x in amax_list]
+    scale_inv_list = [torch.empty_like(x) for x in amax_list]
+
+    applier(
+        tex.multi_tensor_compute_scale_and_scale_inv,
+        overflow_buf,
+        [amax_list, scale_list, scale_inv_list],
+        max_fp8,
+        pow_2_scales,
+        epsilon,
+    )
+
+    for amax, scale, scale_inv in zip(amax_list, scale_list, scale_inv_list):
+        scale_ref, scale_inv_ref = ref_compute_scale_and_scale_inv_from_amax(
+            amax, max_fp8, epsilon, pow_2_scales
+        )
+        torch.testing.assert_close(scale, scale_ref, rtol=0, atol=0)
+        torch.testing.assert_close(scale_inv, scale_inv_ref, rtol=0, atol=0)
diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py
index 1e6250f..980eeef 100644
--- a/tests/pytorch/test_sanity.py
+++ b/tests/pytorch/test_sanity.py
@@ -36,7 +36,12 @@ from transformer_engine.common import recipe
 import transformer_engine_torch as tex
 from transformer_engine.pytorch.cpp_extensions import general_gemm
 from transformer_engine.pytorch.module.base import get_workspace
-from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer
+from transformer_engine.pytorch.tensor import QuantizedTensor
+from transformer_engine.pytorch.tensor.float8_tensor import (
+    Float8Quantizer,
+    Float8CurrentScalingQuantizer,
+)
+from transformer_engine.pytorch.tensor.utils import replace_raw_data
 from test_numerics import reset_rng_states, dtype_tols
 
 # Only run FP8 tests on supported devices.
@@ -1196,3 +1201,70 @@ def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False
             outputs.append(p.grad)
 
     return outputs
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+def test_replace_raw_data_for_float8tensor():
+    """Test the functionality of replace_raw_data"""
+    torch.manual_seed(12345)
+    torch.cuda.manual_seed(12345)
+
+    fp8_quantizer = Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda")
+    fp8_tensor = fp8_quantizer.make_empty([128, 128], dtype=torch.bfloat16, device="cuda")
+    random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
+    fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)
+
+    attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
+    attrs = {}
+    for attr in attrs_to_check:
+        attrs[attr] = getattr(fp8_tensor, attr)
+
+    old_data = fp8_tensor._data
+    new_data = torch.empty_like(old_data)
+    replace_raw_data(fp8_tensor, new_data)
+
+    # Make sure the new_data is properly assigned.
+    assert fp8_tensor._data.data_ptr() != old_data.data_ptr()
+    assert fp8_tensor._data.data_ptr() == new_data.data_ptr()
+    # Make sure the values are not changed.
+    torch.testing.assert_close(old_data, fp8_tensor._data, atol=0, rtol=0)
+    # Make sure other attributes are not changed (totally identical)
+    for attr in attrs_to_check:
+        assert id(getattr(fp8_tensor, attr)) == id(attrs[attr])
+
+
+@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
+def test_fp8_model_init_high_precision_init_val():
+    """Test fp8_model_init with preserve_high_precision_init_val=True"""
+    with fp8_model_init(preserve_high_precision_init_val=True):
+        model = Linear(768, 768)
+
+    weight = model.weight
+
+    assert isinstance(weight, QuantizedTensor), "Weight should be QuantizedTensor"
+    assert hasattr(weight, "_high_precision_init_val"), "_high_precision_init_val not found"
+    assert hasattr(weight, "get_high_precision_init_val"), "get_high_precision_init_val() not found"
+    assert hasattr(
+        weight, "clear_high_precision_init_val"
+    ), "clear_high_precision_init_val() not found"
+
+    high_precision = weight.get_high_precision_init_val()
+    assert high_precision.device.type == "cpu", "high_precision_init_val is not on the CPU"
+
+    new_weight = weight._get_quantizer().make_empty(
+        shape=weight.shape, dtype=weight.dtype, device=weight.device
+    )
+    weight._get_quantizer().update_quantized(high_precision.to(weight.device), new_weight)
+
+    torch.testing.assert_close(
+        new_weight.dequantize(dtype=weight.dtype),
+        weight.dequantize(dtype=weight.dtype),
+        rtol=0,
+        atol=0,
+    )
+
+    weight.clear_high_precision_init_val()
+    assert weight.get_high_precision_init_val() is None, "clear_high_precision_init_val() not work"
+    assert not hasattr(
+        weight, "._high_precision_init_val"
+    ), "clear_high_precision_init_val() not work"
diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu
index 3a25d71..cf07d12 100644
--- a/transformer_engine/common/recipe/current_scaling.cu
+++ b/transformer_engine/common/recipe/current_scaling.cu
@@ -13,6 +13,7 @@
 #include "../common.h"
 #include "../util/logging.h"
 #include "../util/vectorized_pointwise.h"
+#include "recipe_common.cuh"
 
 namespace transformer_engine {
 namespace {
@@ -135,7 +136,7 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt
              "Output tensor for amax computation has invalid amax tensor  "
              "(expected FP32, got dtype=",
              to_string(output.amax.dtype), ")");
-  CheckOutputTensor(output, "output_compute_amax");
+  CheckOutputTensor(output, "output_compute_amax", true);
 
   // Compute amax
   TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
@@ -151,41 +152,7 @@ namespace {
 __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr,
                                                const float max_fp8, const bool force_pow_2_scales,
                                                const float epsilon) {
-  float amax = *amax_ptr;
-  if (amax < epsilon) {
-    amax = epsilon;
-  }
-
-  float scale = 1.f;
-
-  if (isinf(amax) || amax == 0.f) {
-    *scale_ptr = scale;
-    return;
-  }
-
-  scale = max_fp8 / amax;
-
-  // The amax is too small that the scale becoming infinite in FP32. In other word,
-  // the scale is not representable in FP32.
-  if (isinf(scale)) {
-    // use fp32 max to represent the scale
-    scale = std::numeric_limits<float>::max();
-  }
-
-  if (isnan(scale)) {
-    scale = 1.f;
-  }
-
-  if (force_pow_2_scales) {
-    uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
-    scale_bits &= 0xFF800000;
-    // If the exponent was zero, we have a logic error.
-    __builtin_assume(scale_bits != 0);
-    __builtin_assume(scale_bits != 0x80000000);
-    scale = *reinterpret_cast<float *>(&scale_bits);
-  }
-
-  *scale_ptr = scale;
+  *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon);
 }
 
 }  // namespace
diff --git a/transformer_engine/common/recipe/recipe_common.cuh b/transformer_engine/common/recipe/recipe_common.cuh
new file mode 100644
index 0000000..c789a9b
--- /dev/null
+++ b/transformer_engine/common/recipe/recipe_common.cuh
@@ -0,0 +1,56 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#ifndef TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
+#define TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
+
+#include <limits>
+
+namespace transformer_engine {
+
+__device__ __forceinline__ float compute_scale_from_amax(float amax, float max_fp8,
+                                                         bool force_pow_2_scales, float epsilon) {
+  if (amax < epsilon) {
+    amax = epsilon;
+  }
+
+  float scale = 1.f;
+
+  if (isinf(amax) || amax == 0.f) {
+    return scale;
+  }
+
+  // Here we don't use "scale = max_fp8 / amax" because it has different results with/without
+  // "--use_fast_math".
+  // "__fdiv_rn" has the same behavior with "max_fp8 / amax" when not using fast math.
+  scale = __fdiv_rn(max_fp8, amax);
+
+  // The amax is too small that the scale becoming infinite in FP32. In other word,
+  // the scale is not representable in FP32.
+  if (isinf(scale)) {
+    // use fp32 max to represent the scale
+    scale = std::numeric_limits<float>::max();
+  }
+
+  if (isnan(scale)) {
+    scale = 1.f;
+  }
+
+  if (force_pow_2_scales) {
+    uint32_t scale_bits = *reinterpret_cast<uint32_t *>(&scale);
+    scale_bits &= 0xFF800000;
+    // If the exponent was zero, we have a logic error.
+    __builtin_assume(scale_bits != 0);
+    __builtin_assume(scale_bits != 0x80000000);
+    scale = *reinterpret_cast<float *>(&scale_bits);
+  }
+
+  return scale;
+}
+
+}  // namespace transformer_engine
+
+#endif  // TRANSFORMER_ENGINE_RECIPE_RECIPE_COMMON_CUH_
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index e430be0..9561fda 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -252,6 +252,8 @@ at::Tensor scaled_aligned_causal_masked_softmax_backward(at::Tensor output_grads
  * FP8 recipe
  **************************************************************************************************/
 
+void compute_amax(const at::Tensor &tensor, at::Tensor &amax);
+
 void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
@@ -359,6 +361,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag,
                            float momentum, float dampening, float lr, bool nesterov, bool first_run,
                            bool wd_after_momentum, float scale);
 
+void multi_tensor_compute_scale_and_scale_inv_cuda(
+    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
+    float max_fp8, bool force_pow_2_scales, float epsilon);
+
 /***************************************************************************************************
  * padding
  **************************************************************************************************/
diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu
new file mode 100644
index 0000000..d262767
--- /dev/null
+++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/multi_tensor_compute_scale.cu
@@ -0,0 +1,66 @@
+/*************************************************************************
+ * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include <ATen/ATen.h>
+#include <ATen/AccumulateType.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <ATen/cuda/Exceptions.h>
+// Another possibility:
+// #include <torch/all.h>
+
+#include <assert.h>
+// Stringstream is a big hammer, but I want to rely on operator<< for dtype.
+#include <sstream>
+
+#include "common/recipe/recipe_common.cuh"
+#include "common/utils.cuh"
+#include "multi_tensor_apply.cuh"
+#include "type_shim.h"
+
+#define BLOCK_SIZE 256
+
+struct ComputeScaleAndScaleInvFunctor {
+  __device__ __forceinline__ void operator()(int chunk_size, volatile int *noop_gmem,
+                                             TensorListMetadata<3> &tl,  // NOLINT(*)
+                                             float max_fp8, bool force_pow_2_scales,
+                                             float epsilon) {
+    // I'd like this kernel to propagate infs/nans.
+    // if(*noop_gmem == 1)
+    //   return;
+
+    int tensor_loc = tl.block_to_tensor[blockIdx.x];
+    int chunk_idx = tl.block_to_chunk[blockIdx.x];
+    int n = tl.sizes[tensor_loc];
+
+    float *amax = reinterpret_cast<float *>(tl.addresses[0][tensor_loc]);
+    amax += chunk_idx * chunk_size;
+
+    float *scale = reinterpret_cast<float *>(tl.addresses[1][tensor_loc]);
+    scale += chunk_idx * chunk_size;
+
+    float *scale_inv = reinterpret_cast<float *>(tl.addresses[2][tensor_loc]);
+    scale_inv += chunk_idx * chunk_size;
+
+    n -= chunk_idx * chunk_size;
+
+    for (int i_start = threadIdx.x; i_start < n && i_start < chunk_size; i_start += blockDim.x) {
+      float scale_val = transformer_engine::compute_scale_from_amax(amax[i_start], max_fp8,
+                                                                    force_pow_2_scales, epsilon);
+      scale[i_start] = scale_val;
+      transformer_engine::reciprocal(scale_inv + i_start, scale_val);
+    }
+  }
+};
+
+void multi_tensor_compute_scale_and_scale_inv_cuda(
+    int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists,
+    float max_fp8, bool force_pow_2_scales, float epsilon) {
+  using namespace at;
+
+  multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists,
+                        ComputeScaleAndScaleInvFunctor(), max_fp8, force_pow_2_scales, epsilon);
+  AT_CUDA_CHECK(cudaGetLastError());
+}
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index a58fd3a..097cf63 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -178,6 +178,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard<py::gil_scoped_release>());
   m.def("get_fused_attn_backend", &get_fused_attn_backend, "Get Fused Attention backend",
         py::call_guard<py::gil_scoped_release>());
+  m.def("compute_amax", &compute_amax, "Compute amax", py::arg("input"), py::arg("amax"));
   m.def("fused_amax_and_scale_update_after_reduction", &fused_amax_and_scale_update_after_reduction,
         "Update amax history and FP8 scale/scale_inv after reduction",
         py::call_guard<py::gil_scoped_release>());
@@ -265,6 +266,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
   m.def("multi_tensor_sgd", &multi_tensor_sgd_cuda,
         "Fused SGD optimizer for list of contiguous tensors",
         py::call_guard<py::gil_scoped_release>());
+  m.def("multi_tensor_compute_scale_and_scale_inv", &multi_tensor_compute_scale_and_scale_inv_cuda,
+        "Fused compute scale and scale_inv from amax", py::call_guard<py::gil_scoped_release>());
 
   // Data structures
   py::class_<transformer_engine::pytorch::FP8TensorMeta>(m, "FP8TensorMeta")
diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp
index e8a31da..2dc3b69 100644
--- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp
@@ -12,10 +12,27 @@
 #include "common/common.h"
 #include "extensions.h"
 
-void fused_amax_and_scale_update_after_reduction(const at::Tensor &amax_reduction_buffer,
+void compute_amax(const at::Tensor& tensor, at::Tensor& amax) {
+  using namespace transformer_engine;
+  using namespace transformer_engine::pytorch;
+
+  auto input_tensor = tensor.contiguous();
+  const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor);
+
+  TORCH_CHECK(amax.scalar_type() == at::kFloat, "amax must be a float tensor");
+  TORCH_CHECK(amax.numel() == 1, "amax must have exactly one element");
+  TensorWrapper fake_te_output(
+      nullptr, te_input.shape(),
+      transformer_engine::DType::kFloat8E4M3,  // It doesn't matter because we only compute amax.
+      amax.data_ptr<float>());
+
+  nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream());
+}
+
+void fused_amax_and_scale_update_after_reduction(const at::Tensor& amax_reduction_buffer,
                                                  std::vector<at::Tensor> amax_histories,
                                                  std::vector<at::Tensor> scales,
-                                                 const std::string &amax_compute_algo,
+                                                 const std::string& amax_compute_algo,
                                                  transformer_engine::DType fp8_dtype,
                                                  float margin) {
   using namespace transformer_engine;
diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py
index 87298c2..38f829c 100644
--- a/transformer_engine/pytorch/fp8.py
+++ b/transformer_engine/pytorch/fp8.py
@@ -93,6 +93,7 @@ class FP8GlobalStateManager:
     FP8_RECIPE = None
     FP8_DISTRIBUTED_GROUP = None
     FP8_PARAMETERS = False
+    HIGH_PRECISION_INIT_VAL = False
     IS_FIRST_FP8_MODULE = False
     FP8_GRAPH_CAPTURING = False
     FP8_AUTOCAST_DEPTH = 0
@@ -117,6 +118,7 @@ class FP8GlobalStateManager:
         cls.FP8_RECIPE = None
         cls.FP8_DISTRIBUTED_GROUP = None
         cls.FP8_PARAMETERS = False
+        cls.HIGH_PRECISION_INIT_VAL = False
         cls.IS_FIRST_FP8_MODULE = False
         cls.FP8_GRAPH_CAPTURING = False
         cls.FP8_AUTOCAST_DEPTH = 0
@@ -267,6 +269,11 @@ class FP8GlobalStateManager:
         """Should the parameters be stored as FP8"""
         return cls.FP8_PARAMETERS
 
+    @classmethod
+    def with_high_precision_init_val(cls) -> bool:
+        """Should the high precision initial values be stored with FP8 parameters"""
+        return cls.HIGH_PRECISION_INIT_VAL
+
     @classmethod
     def fp8_graph_capturing(cls) -> bool:
         """Is CUDA graph capture under way?"""
@@ -500,7 +507,11 @@ class FP8GlobalStateManager:
 
 
 @contextmanager
-def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> None:
+def fp8_model_init(
+    enabled: bool = True,
+    recipe: Optional[Recipe] = None,
+    preserve_high_precision_init_val: bool = False,
+) -> None:
     """
     Context manager for FP8 initialization of parameters.
 
@@ -511,6 +522,12 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
         with fp8_model_init(enabled=True):
             model = transformer_engine.pytorch.Linear(768, 768)
 
+        # Preserving high precision initial value to initialize master weight
+        with fp8_model_init(enabled=True, preserve_high_precision_init_val=True):
+            model = transformer_engine.pytorch.Linear(768, 768)
+        master_weight = model.weight.get_high_precision_init_val()
+        model.weight.clear_high_precision_init_val()
+
     Parameters
     ----------
     enabled: bool, default = `True`
@@ -526,18 +543,29 @@ def fp8_model_init(enabled: bool = True, recipe: Optional[Recipe] = None) -> Non
              * LoRA-like fine-tuning, where the main parameters of the model do not change.
     recipe: transformer_engine.common.recipe.Recipe, default = `None`
             Recipe used to create the parameters. If left to None, it uses the default FP8 recipe.
+    preserve_high_precision_init_val: bool, default = `False`
+             when enabled, store the high precision tensor used to initialize FP8 parameters
+             in CPU memory, and add two function attributes named `get_high_precision_init_val()`
+             and `clear_high_precision_init_val()` to FP8 parameters to get/clear this high
+             precision tensor. The purpose is that users can use this high-precision copy
+             to initialize master weights, avoiding the loss of precision that can occur when
+             using FP8 parameters directly. Note that after the master weights are initialized,
+             users should call `clear_high_precision_init_val()` to release this CPU memory.
 
              This functionality is *EXPERIMENTAL*.
     """
     _fp8_parameters = FP8GlobalStateManager.FP8_PARAMETERS
     _fp8_recipe = FP8GlobalStateManager.FP8_RECIPE
+    _high_precision_init_val = FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL
     FP8GlobalStateManager.FP8_PARAMETERS = enabled
     FP8GlobalStateManager.FP8_RECIPE = get_default_fp8_recipe() if recipe is None else recipe
+    FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = preserve_high_precision_init_val
     try:
         yield
     finally:
         FP8GlobalStateManager.FP8_PARAMETERS = _fp8_parameters
         FP8GlobalStateManager.FP8_RECIPE = _fp8_recipe
+        FP8GlobalStateManager.HIGH_PRECISION_INIT_VAL = _high_precision_init_val
 
 
 @contextmanager
diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py
index 4b82054..cdb75aa 100644
--- a/transformer_engine/pytorch/module/base.py
+++ b/transformer_engine/pytorch/module/base.py
@@ -10,6 +10,7 @@ import warnings
 from abc import ABC, abstractmethod
 from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union
 from contextlib import contextmanager
+from types import MethodType
 
 import torch
 import torch.nn.functional as F
@@ -405,6 +406,7 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
         self.sequence_parallel = False
         self.param_init_meta = {}
         self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters()
+        self.preserve_high_precision_init_val = FP8GlobalStateManager.with_high_precision_init_val()
         self.fsdp_wrapped = False
         self.fsdp_group = None
         self._fp8_workspaces: Dict[str, QuantizedTensor] = {}
@@ -902,7 +904,11 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
 
             # If primary weights are in fp8, wrap the parameter as FP8Tensor
             fp8_meta_index = self.param_init_meta[name].fp8_meta_index
+            high_precision_init_val = None
             if self.primary_weights_in_fp8 and fp8_meta_index is not None:
+                if self.preserve_high_precision_init_val:
+                    high_precision_init_val = param.detach().cpu()
+
                 quantizer = self.quantizers["scaling_fwd"][fp8_meta_index]
                 assert (
                     quantizer is not None
@@ -914,7 +920,34 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             # NOTE: Currently this can only be broken when primary weights are in Fp8 but
             #       re-applying the nn.Parameter() wrap is a no-op when the input is already
             #       a parameter so we always re-apply it just for extra safety.
-            setattr(self, name, torch.nn.Parameter(param))
+            param = torch.nn.Parameter(param)
+            if high_precision_init_val is not None:
+
+                # - Master weights are initialized from model weights, if we use fp8 primary
+                #   weights to initialize master weights, the numerical values of master weights
+                #   are not consistent with the numerical values when we initialize them from
+                #   bf16/fp16 weights.
+                # - So we add a `_high_precision_init_val` attribute to each model weight to store
+                #   the original bf16/fp16 weight on cpu before casting it to fp8. And users can
+                #   use `get_high_precision_init_val` to get this cpu tensor.
+                # - This cpu tensor is not needed once the master weight is initialized, so users
+                #   should call `clear_high_precision_init_val` to remove it after master weight
+                #   is initialized.
+
+                def get(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        return self._high_precision_init_val
+                    return None
+
+                def clear(self):
+                    if hasattr(self, "_high_precision_init_val"):
+                        del self._high_precision_init_val
+
+                param._high_precision_init_val = high_precision_init_val
+                param.get_high_precision_init_val = MethodType(get, param)
+                param.clear_high_precision_init_val = MethodType(clear, param)
+
+            setattr(self, name, param)
 
     @abstractmethod
     def forward(self):
@@ -953,10 +986,26 @@ class TransformerEngineBaseModule(torch.nn.Module, ABC):
             FSDP process group that the weights are distributed over.
         """
 
+        # FP8 primary weights
+        if isinstance(tensor, QuantizedTensor):
+            if update_workspace and quantizer is not None:
+                tensor.update_usage(
+                    rowwise_usage=quantizer.rowwise_usage,
+                    columnwise_usage=quantizer.columnwise_usage,
+                )
+            return tensor
+
         # Try getting workspace from cache
         out = None
         if cache_name is not None:
             out = self._fp8_workspaces.get(cache_name, None)
+            if quantizer is not None and isinstance(out, MXFP8TensorBase):
+                if quantizer.rowwise_usage and out._rowwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]
+                elif quantizer.columnwise_usage and out._columnwise_data is None:
+                    out = None
+                    del self._fp8_workspaces[cache_name]
 
         # Gather cached Fp8 workspace if it's distributed
         # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work
diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py
index 8bf420a..8963a61 100644
--- a/transformer_engine/pytorch/module/grouped_linear.py
+++ b/transformer_engine/pytorch/module/grouped_linear.py
@@ -130,20 +130,17 @@ class _GroupedLinear(torch.autograd.Function):
             )
             weights_fp8 = []
             bias_dtype = torch.bfloat16 if activation_dtype == torch.float32 else activation_dtype
-            if not isinstance(weights[0], QuantizedTensor):
-                # FP8 cast to workspace buffer
-                update_workspace = is_first_microbatch is None or is_first_microbatch
-                for i in range(num_gemms):
-                    weight_fp8 = module.get_weight_workspace(
-                        tensor=weights[i],
-                        quantizer=weight_quantizers[i],
-                        cache_name=(None if is_first_microbatch is None else f"weight{i}"),
-                        update_workspace=update_workspace,
-                        skip_update_flag=skip_fp8_weight_update,
-                    )
-                    weights_fp8.append(weight_fp8)
-            else:
-                weights_fp8 = weights
+            # FP8 cast to workspace buffer
+            update_workspace = is_first_microbatch is None or is_first_microbatch
+            for i in range(num_gemms):
+                weight_fp8 = module.get_weight_workspace(
+                    tensor=weights[i],
+                    quantizer=weight_quantizers[i],
+                    cache_name=(None if is_first_microbatch is None else f"weight{i}"),
+                    update_workspace=update_workspace,
+                    skip_update_flag=skip_fp8_weight_update,
+                )
+                weights_fp8.append(weight_fp8)
 
         else:
             inputmats = inputmats_no_fp8
@@ -180,7 +177,7 @@ class _GroupedLinear(torch.autograd.Function):
                     weight_quantizers[i].calibrate(weights[i])
 
         if is_grad_enabled:
-
+            ctx.weight_quantizers = weight_quantizers
             ctx.weights_shape_1 = weights[0].shape[1]
 
             tensors_to_save, tensor_objects = prepare_for_saving(*inputmats, *weights_fp8, *biases)
@@ -270,6 +267,12 @@ class _GroupedLinear(torch.autograd.Function):
                     device=ctx.device,
                 )
 
+                for weight, quantizer in zip(weights, ctx.weight_quantizers):
+                    if quantizer is not None and isinstance(weight, QuantizedTensor):
+                        weight.update_usage(
+                            rowwise_usage=quantizer.rowwise_usage,
+                            columnwise_usage=quantizer.columnwise_usage,
+                        )
                 general_grouped_gemm(
                     weights,
                     grad_output,
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 4022924..fc316e3 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -262,28 +262,26 @@ class _LayerNormLinear(torch.autograd.Function):
         nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")
 
         # Cast weight to expected dtype
-        weightmat = weight
-        quantized_weight = False
         if not fp8:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
+            quantized_weight = False
+            weightmat = cast_if_needed(weight, activation_dtype)
         else:
-            if not isinstance(weight, QuantizedTensor):
-                quantized_weight = True
-
-                # Configure quantizer
-                if weight_quantizer is not None:
-                    weight_quantizer.set_usage(rowwise=True, columnwise=True)
-
-                # FP8 cast to workspace buffer
-                update_workspace = is_first_microbatch is None or is_first_microbatch
-                weightmat = module.get_weight_workspace(
-                    tensor=weight,
-                    quantizer=weight_quantizer,
-                    cache_name=(None if is_first_microbatch is None else "weight"),
-                    update_workspace=update_workspace,
-                    skip_update_flag=skip_fp8_weight_update,
-                    fsdp_group=fsdp_group,
-                )
+            quantized_weight = not isinstance(weight, QuantizedTensor)
+
+            # Configure quantizer
+            if weight_quantizer is not None:
+                weight_quantizer.set_usage(rowwise=True, columnwise=True)
+
+            # FP8 cast to workspace buffer
+            update_workspace = is_first_microbatch is None or is_first_microbatch
+            weightmat = module.get_weight_workspace(
+                tensor=weight,
+                quantizer=weight_quantizer,
+                cache_name=(None if is_first_microbatch is None else "weight"),
+                update_workspace=update_workspace,
+                skip_update_flag=skip_fp8_weight_update,
+                fsdp_group=fsdp_group,
+            )
 
         # Cast bias to expected dtype
         bias_dtype = activation_dtype
@@ -345,11 +343,12 @@ class _LayerNormLinear(torch.autograd.Function):
                 clear_tensor_data(ln_out, ln_out_total)
 
         if is_grad_enabled:
+            ctx.weight_quantizer = weight_quantizer
             ctx.ln_out_needs_gather = (
                 weight.requires_grad and parallel_mode == "column" and sequence_parallel
             )
 
-            # Input with column-wise usage is needed for dgrad GEMM.
+            # Input with column-wise usage is needed for wgrad GEMM.
             if backward_needs_input:
                 if isinstance(ln_out, QuantizedTensor):
                     # For sequence parallel in vanilla FP8, rowwise data is
@@ -358,6 +357,11 @@ class _LayerNormLinear(torch.autograd.Function):
                     if isinstance(ln_out, MXFP8TensorBase) or not ctx.ln_out_needs_gather:
                         ln_out.update_usage(rowwise_usage=False)
 
+            # Weight with column-wise usage is needed for dgrad GEMM.
+            if inp.requires_grad:
+                if isinstance(weightmat, QuantizedTensor):
+                    weightmat.update_usage(columnwise_usage=True)
+
             if cpu_offloading:
                 if fp8 and weightmat is not None:
                     set_offloading_param(weightmat, "weight_offloading", True)
@@ -642,6 +646,11 @@ class _LayerNormLinear(torch.autograd.Function):
                 if hasattr(recipe, "fp8_gemm_dgrad"):
                     dgrad_gemm_use_split_accumulator = recipe.fp8_gemm_dgrad.use_split_accumulator
 
+            if ctx.weight_quantizer is not None and isinstance(weight, QuantizedTensor):
+                weight.update_usage(
+                    rowwise_usage=ctx.weight_quantizer.rowwise_usage,
+                    columnwise_usage=ctx.weight_quantizer.columnwise_usage,
+                )
             dgrad, *_ = general_gemm(
                 weight,
                 grad_output,
@@ -1274,6 +1283,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
         inp: torch.Tensor,
         is_first_microbatch: Optional[bool] = None,
         fp8_output: Optional[bool] = False,
+        fp8_grad: Optional[bool] = False,
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         """
         Apply layer normalization to the input followed by a linear transformation.
@@ -1304,6 +1314,13 @@ class LayerNormLinear(TransformerEngineBaseModule):
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
+        if self.ub_overlap_rs_fprop:
+            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+                fp8_output = True
+        if self.ub_overlap_rs_dgrad:
+            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+                fp8_grad = True
+
         with self.prepare_forward(
             inp, allow_non_contiguous=False  # removed .contiguous from inside the layer
         ) as inp:
@@ -1331,7 +1348,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
                 output_quantizer,
                 grad_output_quantizer,
                 grad_input_quantizer,
-            ) = self._get_quantizers(fp8_output)
+            ) = self._get_quantizers(fp8_output, fp8_grad)
 
             if torch.is_grad_enabled():
                 fwd_fn = _LayerNormLinear.apply
@@ -1397,7 +1414,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             return out, ln_out
         return out
 
-    def _get_quantizers(self, fp8_output):
+    def _get_quantizers(self, fp8_output, fp8_grad):
         if not self.fp8:
             return [None] * 5
         grad_input_quantizer = None
@@ -1412,6 +1429,8 @@ class LayerNormLinear(TransformerEngineBaseModule):
         if torch.is_grad_enabled():
             grad_output_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
             grad_output_quantizer.internal = True
+            if fp8_grad:
+                grad_input_quantizer = self.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
 
         return (
             input_quantizer,
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index 633690b..9cffc47 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -319,35 +319,31 @@ class _LayerNormMLP(torch.autograd.Function):
                 ln_out_total = ln_out
 
         # Cast weights to expected dtype
-        fc1_weight_final = fc1_weight
-        fc2_weight_final = fc2_weight
         if not fp8:
-            fc1_weight_final = cast_if_needed(fc1_weight_final, activation_dtype)
-            fc2_weight_final = cast_if_needed(fc2_weight_final, activation_dtype)
+            fc1_weight_final = cast_if_needed(fc1_weight, activation_dtype)
+            fc2_weight_final = cast_if_needed(fc2_weight, activation_dtype)
         else:
             # If weights are not quantized, we call get_weight_workspace,
             # which handles weight caching etc.
-            if not isinstance(fc1_weight, QuantizedTensor):
-                # FP8 cast to workspace buffer
-                update_workspace = is_first_microbatch is None or is_first_microbatch
-                fc1_weight_final = module.get_weight_workspace(
-                    tensor=fc1_weight,
-                    quantizer=fc1_weight_quantizer,
-                    cache_name=(None if is_first_microbatch is None else "fc1_weight"),
-                    update_workspace=update_workspace,
-                    skip_update_flag=skip_fp8_weight_update,
-                    fsdp_group=fsdp_group,
-                )
-            if not isinstance(fc2_weight, QuantizedTensor):
-                fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
-                fc2_weight_final = module.get_weight_workspace(
-                    tensor=fc2_weight,
-                    quantizer=fc2_weight_quantizer,
-                    cache_name=(None if is_first_microbatch is None else "fc2_weight"),
-                    update_workspace=update_workspace,
-                    skip_update_flag=skip_fp8_weight_update,
-                    fsdp_group=fsdp_group,
-                )
+            # FP8 cast to workspace buffer
+            update_workspace = is_first_microbatch is None or is_first_microbatch
+            fc1_weight_final = module.get_weight_workspace(
+                tensor=fc1_weight,
+                quantizer=fc1_weight_quantizer,
+                cache_name=(None if is_first_microbatch is None else "fc1_weight"),
+                update_workspace=update_workspace,
+                skip_update_flag=skip_fp8_weight_update,
+                fsdp_group=fsdp_group,
+            )
+            fc2_weight_quantizer.set_usage(rowwise=True, columnwise=True)
+            fc2_weight_final = module.get_weight_workspace(
+                tensor=fc2_weight,
+                quantizer=fc2_weight_quantizer,
+                cache_name=(None if is_first_microbatch is None else "fc2_weight"),
+                update_workspace=update_workspace,
+                skip_update_flag=skip_fp8_weight_update,
+                fsdp_group=fsdp_group,
+            )
 
         # Cast biases to expected dtype
         bias_dtype = activation_dtype
@@ -430,7 +426,6 @@ class _LayerNormMLP(torch.autograd.Function):
             dim_size[0] = dim_size[0] // tp_world_size
             dim_size[1] = fc2_weight.size(0)
             rs_out = torch.empty(dim_size, dtype=activation_dtype, device=device)
-            fc2_out = ub_obj_fc2out.get_buffer(output_quantizer)
         else:
             dim_size = list(act_out.size())
             dim_size[1] = fc2_weight.size(0)
@@ -450,6 +445,14 @@ class _LayerNormMLP(torch.autograd.Function):
             ub_type=tex.CommOverlapType.RS if ub_overlap_rs else None,
             extra_output=rs_out,
         )
+
+        # Weight with column-wise usage is needed for dgrad GEMM.
+        if is_grad_enabled and inp.requires_grad:
+            if isinstance(fc1_weight_final, QuantizedTensor):
+                fc1_weight_final.update_usage(columnwise_usage=True)
+            if isinstance(fc2_weight_final, QuantizedTensor):
+                fc2_weight_final.update_usage(columnwise_usage=True)
+
         if not is_grad_enabled:
             clear_tensor_data(act_out, fc1_out_without_bias, fc1_out)
 
@@ -488,6 +491,8 @@ class _LayerNormMLP(torch.autograd.Function):
                 fc2_weight_final if fp8 and not isinstance(fc2_weight, Float8Tensor) else None,
             )
 
+            ctx.fc1_weight_quantizer = fc1_weight_quantizer
+            ctx.fc2_weight_quantizer = fc2_weight_quantizer
             if not fc1_weight.requires_grad:
                 if not return_layernorm_output:
                     clear_tensor_data(ln_out)
@@ -500,11 +505,13 @@ class _LayerNormMLP(torch.autograd.Function):
                 ln_weight,
                 ln_out.clone() if ub_overlap_ag else ln_out,  # avoid saving a UB buffer
                 fc1_weight_final,
+                fc1_weight,
                 fc1_bias,
                 fc1_out,
                 fc1_out_without_bias,
                 act_out,
                 fc2_weight_final,
+                fc2_weight,
                 fc2_bias,
                 mu,
                 rsigma,
@@ -619,11 +626,13 @@ class _LayerNormMLP(torch.autograd.Function):
                 ln_weight,
                 ln_out,
                 fc1_weight,
+                origin_fc1_weight,
                 fc1_bias,
                 fc1_out,
                 fc1_out_without_bias,
                 act_out,
                 fc2_weight,
+                origin_fc2_weight,
                 fc2_bias,
                 mu,
                 rsigma,
@@ -642,7 +651,7 @@ class _LayerNormMLP(torch.autograd.Function):
             )
             fc2_weight_main_grad = (
                 ctx.fc2_main_grad
-                if fc2_weight is not None
+                if origin_fc2_weight is not None
                 and ctx.fuse_wgrad_accumulation
                 and ctx.fc2_weight_requires_grad
                 else None
@@ -651,8 +660,8 @@ class _LayerNormMLP(torch.autograd.Function):
             # For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
             # we need to connect them into one.
             if ctx.fuse_wgrad_accumulation:
-                fc1_weight.main_grad = fc1_weight_main_grad
-                fc2_weight.main_grad = fc2_weight_main_grad
+                origin_fc1_weight.main_grad = fc1_weight_main_grad
+                origin_fc2_weight.main_grad = fc2_weight_main_grad
 
             # TODO: Fix this  # pylint: disable=fixme
             # Gather saved autograd context tensors when running with FSDP
@@ -735,6 +744,11 @@ class _LayerNormMLP(torch.autograd.Function):
             )
 
             # FC2 DGRAD; Unconditional
+            if ctx.fc2_weight_quantizer is not None and isinstance(ctx.fc2_weight, QuantizedTensor):
+                ctx.fc2_weight.update_usage(
+                    rowwise_usage=ctx.fc2_weight_quantizer.rowwise_usage,
+                    columnwise_usage=ctx.fc2_weight_quantizer.columnwise_usage,
+                )
             gemm_output, *_ = general_gemm(
                 fc2_weight,
                 grad_output,
@@ -769,14 +783,18 @@ class _LayerNormMLP(torch.autograd.Function):
                     act_out,
                     grad_output,
                     get_workspace(),
-                    out_dtype=ctx.activation_dtype,
+                    out_dtype=(
+                        origin_fc2_weight.main_grad.dtype
+                        if ctx.fuse_wgrad_accumulation
+                        else ctx.activation_dtype
+                    ),
                     quantization_params=None,  # wgrad in high precision
                     layout="NT",
                     grad=True,
                     bias=fc2_bias if fc2_bias is not None and fc2_bias_grad is None else None,
                     accumulate=accumulate_wgrad_into_param_main_grad,
                     use_split_accumulator=_2X_ACC_WGRAD,
-                    out=fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                    out=origin_fc2_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                 )
                 if fc2_bias_grad is None:
                     fc2_bias_grad = fc2_bias_grad_
@@ -864,6 +882,13 @@ class _LayerNormMLP(torch.autograd.Function):
                     fc1_dgrad_bulk = ub_obj_fc1_wgrad.get_buffer(None)
 
             # FC1 DGRAD: Unconditional
+            if ctx.fc1_weight_quantizer is not None and isinstance(
+                ctx.fc1_weight_quantizer, QuantizedTensor
+            ):
+                ctx.fc1_weight.update_usage(
+                    rowwise_usage=ctx.fc1_weight_quantizer.rowwise_usage,
+                    columnwise_usage=ctx.fc1_weight_quantizer.columnwise_usage,
+                )
             fc1_dgrad, *_, fc1_dgrad_rs_out = general_gemm(
                 fc1_weight,
                 dact,
@@ -930,12 +955,16 @@ class _LayerNormMLP(torch.autograd.Function):
                     ln_out_total,
                     dact,
                     get_workspace(),
-                    out_dtype=ctx.activation_dtype,
+                    out_dtype=(
+                        origin_fc1_weight.main_grad.dtype
+                        if ctx.fuse_wgrad_accumulation
+                        else ctx.activation_dtype
+                    ),
                     layout="NT",
                     grad=fuse_gemm_and_bias_fc1_wgrad,
                     bias=fc1_bias if fuse_gemm_and_bias_fc1_wgrad else None,
                     accumulate=accumulate_wgrad_into_param_main_grad,
-                    out=fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
+                    out=origin_fc1_weight.main_grad if ctx.fuse_wgrad_accumulation else None,
                     ub=ub_obj_fc1_wgrad,
                     ub_type=tex.CommOverlapType.RS if ctx.ub_bulk_wgrad else None,
                     extra_output=fc1_dgrad_rs_out,
@@ -996,16 +1025,21 @@ class _LayerNormMLP(torch.autograd.Function):
         if ctx.fc1_weight_requires_grad:
             # Handle custom DDP from mcore.
             if ctx.fuse_wgrad_accumulation and hasattr(fc1_weight, "grad_added_to_main_grad"):
-                fc1_weight.grad_added_to_main_grad = True
-                if getattr(fc1_weight, "zero_out_wgrad", False):
+                origin_fc1_weight.grad_added_to_main_grad = True
+                if getattr(origin_fc1_weight, "zero_out_wgrad", False):
                     fc1_wgrad = torch.zeros(
-                        fc1_weight.main_grad.shape,
-                        dtype=fc1_weight.dtype,
+                        origin_fc1_weight.main_grad.shape,
+                        dtype=origin_fc1_weight.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False,
                     )
                 else:
-                    fc1_wgrad = None
+                    fc1_wgrad = torch.empty(
+                        origin_fc1_weight.main_grad.shape,
+                        dtype=origin_fc1_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 fc1_wgrad = None
         else:
@@ -1013,17 +1047,24 @@ class _LayerNormMLP(torch.autograd.Function):
 
         if ctx.fc2_weight_requires_grad:
             # Handle custom DDP from mcore.
-            if ctx.fuse_wgrad_accumulation and hasattr(fc2_weight, "grad_added_to_main_grad"):
-                fc2_weight.grad_added_to_main_grad = True
-                if getattr(fc2_weight, "zero_out_wgrad", False):
+            if ctx.fuse_wgrad_accumulation and hasattr(
+                origin_fc2_weight, "grad_added_to_main_grad"
+            ):
+                origin_fc2_weight.grad_added_to_main_grad = True
+                if getattr(origin_fc2_weight, "zero_out_wgrad", False):
                     fc2_wgrad = torch.zeros(
-                        fc2_weight.main_grad.shape,
-                        dtype=fc2_weight.dtype,
+                        origin_fc2_weight.main_grad.shape,
+                        dtype=origin_fc2_weight.dtype,
                         device=torch.cuda.current_device(),
                         requires_grad=False,
                     )
                 else:
-                    fc2_wgrad = None
+                    fc2_wgrad = torch.empty(
+                        origin_fc2_weight.main_grad.shape,
+                        dtype=origin_fc2_weight.dtype,
+                        device=torch.cuda.current_device(),
+                        requires_grad=False,
+                    )
             elif ctx.fuse_wgrad_accumulation:
                 fc2_wgrad = None
         else:
@@ -1429,6 +1470,11 @@ class LayerNormMLP(TransformerEngineBaseModule):
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
+        fp8_output = False
+        if self.ub_overlap_rs:
+            if get_ub("fc2_fprop").is_fp8_ubuf():
+                fp8_output = True
+
         with self.prepare_forward(inp, num_gemms=2) as inp:
             # Get quantizers
             (
@@ -1440,7 +1486,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
                 grad_fc1_output_quantizer,
                 grad_fc2_output_quantizer,
                 grad_input_quantizer,
-            ) = self._get_quantizers()
+            ) = self._get_quantizers(fp8_output)
 
             # Get weight tensors
             fc1_weight = self.fc1_weight
@@ -1528,7 +1574,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             return out, ln_out
         return out
 
-    def _get_quantizers(self):
+    def _get_quantizers(self, fp8_output):
         (
             fc1_input_quantizer,
             fc1_weight_quantizer,
@@ -1550,6 +1596,8 @@ class LayerNormMLP(TransformerEngineBaseModule):
             )
             fc2_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_WEIGHT]
             fc2_weight_quantizer.internal = True
+            if fp8_output:
+                output_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_OUTPUT]
             if torch.is_grad_enabled():
                 grad_fc2_output_quantizer = self.quantizers["scaling_bwd"][
                     tex.FP8BwdTensors.GRAD_OUTPUT1
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index f96355a..91dfe92 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -176,31 +176,29 @@ class _Linear(torch.autograd.Function):
         nvtx_range_pop(f"{nvtx_label}.input_cast_comm")
 
         # Cast weight to expected dtype
-        weightmat = weight
         if not fp8:
-            weightmat = cast_if_needed(weightmat, activation_dtype)
+            weightmat = cast_if_needed(weight, activation_dtype)
         else:
-            if not isinstance(weight, QuantizedTensor):
-                # Configure quantizer
-                if weight_quantizer is not None:
-                    columnwise_usage = is_grad_enabled and inp.requires_grad
-                    if not columnwise_usage:
-                        columnwise_usage = (
-                            is_fp8_activation_recompute_enabled()
-                            and not in_fp8_activation_recompute_phase()
-                        )
-                    weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
-
-                # FP8 cast to workspace buffer
-                update_workspace = is_first_microbatch is None or is_first_microbatch
-                weightmat = module.get_weight_workspace(
-                    tensor=weight,
-                    quantizer=weight_quantizer,
-                    cache_name=(None if is_first_microbatch is None else "weight"),
-                    update_workspace=update_workspace,
-                    skip_update_flag=skip_fp8_weight_update,
-                    fsdp_group=fsdp_group,
-                )
+            # Configure quantizer
+            if weight_quantizer is not None:
+                columnwise_usage = is_grad_enabled and inp.requires_grad
+                if not columnwise_usage:
+                    columnwise_usage = (
+                        is_fp8_activation_recompute_enabled()
+                        and not in_fp8_activation_recompute_phase()
+                    )
+                weight_quantizer.set_usage(rowwise=True, columnwise=columnwise_usage)
+
+            # FP8 cast to workspace buffer
+            update_workspace = is_first_microbatch is None or is_first_microbatch
+            weightmat = module.get_weight_workspace(
+                tensor=weight,
+                quantizer=weight_quantizer,
+                cache_name=(None if is_first_microbatch is None else "weight"),
+                update_workspace=update_workspace,
+                skip_update_flag=skip_fp8_weight_update,
+                fsdp_group=fsdp_group,
+            )
 
         # Cast bias to expected dtype
         bias_dtype = activation_dtype
@@ -259,6 +257,7 @@ class _Linear(torch.autograd.Function):
         nvtx_range_pop(f"{nvtx_label}.gemm")
 
         if is_grad_enabled:
+            ctx.weight_quantizer = weight_quantizer
             saved_inputmat = None
 
             ctx.backward_input_needs_gather = (
@@ -274,6 +273,11 @@ class _Linear(torch.autograd.Function):
                         inputmat.update_usage(rowwise_usage=False)
                 saved_inputmat = inputmat
 
+            # Weight with column-wise usage is needed for dgrad GEMM.
+            if inp.requires_grad:
+                if isinstance(weightmat, QuantizedTensor):
+                    weightmat.update_usage(columnwise_usage=True)
+
             if cpu_offloading:
                 set_offloading_param(weight, "weight_offloading", True)
                 set_offloading_param(weightmat, "weight_offloading", True)
@@ -530,6 +534,12 @@ class _Linear(torch.autograd.Function):
                             recipe.fp8_gemm_dgrad.use_split_accumulator
                         )
 
+                if ctx.weight_quantizer is not None and isinstance(weight_fp8, QuantizedTensor):
+                    weight_fp8.update_usage(
+                        rowwise_usage=ctx.weight_quantizer.rowwise_usage,
+                        columnwise_usage=ctx.weight_quantizer.columnwise_usage,
+                    )
+
                 dgrad, *_, rs_out = general_gemm(
                     weight_fp8,
                     grad_output,
@@ -1077,6 +1087,13 @@ class Linear(TransformerEngineBaseModule):
         if skip_fp8_weight_update is not None:
             is_first_microbatch = False
 
+        if self.ub_overlap_rs_fprop:
+            if get_ub(self.ub_name + "_fprop").is_fp8_ubuf():
+                fp8_output = True
+        if self.ub_overlap_rs_dgrad:
+            if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf():
+                fp8_grad = True
+
         with self.prepare_forward(
             inp,
             allow_non_contiguous=isinstance(inp, QuantizedTensor),
diff --git a/transformer_engine/pytorch/tensor/__init__.py b/transformer_engine/pytorch/tensor/__init__.py
index 610ec2a..22b86fb 100644
--- a/transformer_engine/pytorch/tensor/__init__.py
+++ b/transformer_engine/pytorch/tensor/__init__.py
@@ -7,6 +7,7 @@
 import torch
 
 from .quantized_tensor import QuantizedTensor, Quantizer
+from .utils import cast_master_weights_to_fp8, replace_raw_data
 
 __all__ = [
     "QuantizedTensor",
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index e45010b..2fb1283 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -185,9 +185,9 @@ class Float8CurrentScalingQuantizer(Quantizer):
 
     """
 
-    """Scaling factor to multiply when quantizing to FP8"""
+    """Workspace buffer for FP8 scaling factor"""
     scale: torch.Tensor
-    """Max-abs value from last FP8 cast"""
+    """Workspace buffer for max-abs value"""
     amax: torch.Tensor
     """FP8 datatype"""
     dtype: TE_DType
diff --git a/transformer_engine/pytorch/tensor/utils.py b/transformer_engine/pytorch/tensor/utils.py
new file mode 100644
index 0000000..8dd04b5
--- /dev/null
+++ b/transformer_engine/pytorch/tensor/utils.py
@@ -0,0 +1,315 @@
+# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# See LICENSE for license information.
+
+"""Helper functions for using fp8 tensors as weights"""
+
+import torch
+
+import transformer_engine_torch as tex
+from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv
+
+from .quantized_tensor import QuantizedTensor
+from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
+from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
+from ..optimizers.multi_tensor_apply import multi_tensor_applier
+
+
+def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
+    r"""Change a quantized tensor's data buffer while preserving values
+
+    This function modifies only the address space of the underlying
+    raw data and does not alter any other tensor attributes or values.
+
+    This may be used for custom buffer allocations, e.g. packing
+    multiple parameter tensors together into a single contiguous
+    buffer for ZeRO-2.
+
+    """
+    if isinstance(tensor, Float8Tensor):
+        old_raw_data = tensor._data
+        assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match"
+        new_raw_data.detach().copy_(old_raw_data)
+        tensor._data = new_raw_data
+        del old_raw_data
+    elif isinstance(tensor, MXFP8Tensor):
+        raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet")
+    else:
+        raise ValueError(f"replace_raw_data for {type(tensor)} is not supported yet")
+
+
+def cast_master_weights_to_fp8(
+    model_weights, master_weights, start_offsets, group, fsdp_shard_model_weights=None
+):
+    r"""Helper function to cast master weights to FP8 primary weights.
+
+    This is intended for use with ZeRO/FSDP. Each rank has a shard of
+    the master weights (possibly empty) and a full copy of the model
+    weights.
+
+    Parameters
+    ----------
+    model_weights  : list of FP8 weights.
+    master_weights : list of master weights. Typically they are FP32 weights.
+    start_offsets  : list of integers, the starting index of the master weight in the model weight.
+                     master_weight may be smaller than model_weight because it could be distributed
+                     across multiple ranks. These offsets indicate which part of the model_weight
+                     should be updated.
+    group          : The distributed group to do amax reduction. Typically it's the data parallel
+                     group.
+    fsdp_shard_model_weights : list of FSDP shard model weights. If None, it means that the model weights are
+                             not sharded. Otherwise, it means that the model weights are sharded and we get
+                             target model weights data storage using the FSDP shard model weights.
+
+    """
+
+    delayed_scaling_params = []
+    current_scaling_params = []
+
+    if fsdp_shard_model_weights is None:
+        use_fsdp_shard_model_weights = False
+        fsdp_shard_model_weights = [None] * len(model_weights)
+    else:
+        use_fsdp_shard_model_weights = True
+
+    for model_weight, master_weight, start_offset, fsdp_shard_model_weight in zip(
+        model_weights, master_weights, start_offsets, fsdp_shard_model_weights
+    ):
+        # Clear `_high_precision_init_val` of model_weight automatically.
+        # - Master weights are initialized from model weights, if we use fp8 primary weights to
+        #   initialize master weights, the numerical values of master weights are not consistent
+        #   with the numerical values when we initialize them from bf16/fp16 weights.
+        # - So we add a `_high_precision_init_val` attribute to each model weight to store the
+        #   original bf16/fp16 weight on cpu before casting it to fp8. And users can use
+        #   `get_high_precision_init_val` to get this cpu tensor.
+        # - This cpu tensor is not needed once the master weight is initialized, so users should
+        #   call `clear_high_precision_init_val` to remove it after master weight is initialized.
+        # - In case users don't call `clear_high_precision_init_val`, we will clear it automatically
+        #   here. It's safe to clear the `_high_precision_init_val` at this time because this
+        #   function is supposed to be called after the master weights are initialized and updated.
+        if hasattr(model_weight, "clear_high_precision_init_val"):
+            model_weight.clear_high_precision_init_val()
+
+        if master_weight is not None:
+            # When not using fp8_primary_weights, the master_weight (fp32) is first cast to
+            # bf16/fp16, and then cast to fp8 during forward. Although it's not necessary when
+            # fp8_primary_weights is enabled, we still keep this logic to keep numerical
+            # consistency. So here we cast the master_weight to model_weight.dtype.
+            master_weight = master_weight.to(model_weight.dtype)
+
+        quantizer = model_weight._get_quantizer()
+        if isinstance(quantizer, Float8Quantizer):
+            delayed_scaling_params.append(
+                (model_weight, master_weight, start_offset, fsdp_shard_model_weight)
+            )
+        elif isinstance(quantizer, Float8CurrentScalingQuantizer):
+            current_scaling_params.append(
+                (model_weight, master_weight, start_offset, fsdp_shard_model_weight)
+            )
+        elif isinstance(quantizer, MXFP8Quantizer):
+            raise NotImplementedError(
+                "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"
+            )
+        else:
+            raise ValueError(
+                f"cast_master_weights_to_fp8 for {type(quantizer)} is not supported yet"
+            )
+
+    if len(delayed_scaling_params) > 0:
+        _cast_master_weights_to_fp8_delayed_scaling(
+            delayed_scaling_params, group, use_fsdp_shard_model_weights
+        )
+    if len(current_scaling_params) > 0:
+        _cast_master_weights_to_fp8_current_scaling(
+            current_scaling_params, group, use_fsdp_shard_model_weights
+        )
+
+
+def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):
+    r"""Helper function to cast master weights to FP8 primary weights for delayed scaling.
+
+    Parameters
+    ----------
+    params : List of tuple, each tuple contains a model weight, a master weight, and an offset
+             indicating the starting index of the master weight in the model weight.
+    group  : The distributed group to do amax reduction. Typically it's the data parallel
+             group.
+    use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
+    """
+
+    # Collect amaxes to do reduce-max among dp group.
+    # Collect scales and scale_invs to update scale_invs of the fp8 weights.
+    amaxes, scales, scale_invs = [], [], []
+
+    for model_weight, master_weight, start_offset, shard_model_weight_raw in params:
+        # Reset transpose cache for all model weights.
+        # We cannot create transpose cache here because users (like megatron) may want to overlap
+        # the all-gather of model weights and forward process, so the model weight is not updated
+        # currently.
+        model_weight._reset_caches()
+
+        quantizer = model_weight._get_quantizer()
+
+        amaxes.append(quantizer.amax.view(1))
+        scales.append(quantizer.scale.view(1))
+        scale_invs.append(model_weight._scale_inv.view(1))
+
+        # If master weight is None, it means that the master weight of the current model weight
+        # is in other DP ranks.
+        if master_weight is None:
+            continue
+
+        # If master weight is not None, start_offset must be a valid value.
+        assert start_offset is not None
+        assert start_offset >= 0
+        end_offset = start_offset + master_weight.numel()
+        assert end_offset <= model_weight.numel()
+
+        # master_weight may be smaller than model_weight because it could be distributed across
+        # multiple ranks. So we need to create a dummy weight using the raw data from model_weight.
+        if not use_fsdp_shard_model_weights:
+            shard_model_weight_raw = model_weight._data.view(-1)[start_offset:end_offset]
+        shard_model_weight_fp8 = quantizer.create_tensor_from_data(
+            shard_model_weight_raw.view(1, -1),
+            model_weight.dtype,
+        )
+
+        # Cast master weight to fp8.
+        quantizer.update_quantized(master_weight.view(1, -1), shard_model_weight_fp8)
+
+    if len(amaxes) > 0:
+        dummy_overflow_buf = torch.tensor([0], dtype=torch.int, device=amaxes[0].device)
+
+        # Reduce amaxes.
+        packed_amaxes = torch.empty(len(amaxes), dtype=torch.float32, device=amaxes[0].device)
+        packed_amax_views = [packed_amaxes[i].view(1) for i in range(len(amaxes))]
+        multi_tensor_applier(
+            multi_tensor_scale, dummy_overflow_buf, [amaxes, packed_amax_views], 1.0
+        )
+        torch.distributed.all_reduce(
+            packed_amaxes,
+            op=torch.distributed.ReduceOp.MAX,
+            group=group,
+        )
+        multi_tensor_applier(
+            multi_tensor_scale, dummy_overflow_buf, [packed_amax_views, amaxes], 1.0
+        )
+
+        # Update scale_invs.
+        packed_scales = torch.empty(len(scales), dtype=torch.float32, device=scales[0].device)
+        packed_scale_views = [packed_scales[i].view(1) for i in range(len(scales))]
+        multi_tensor_applier(
+            multi_tensor_scale, dummy_overflow_buf, [scales, packed_scale_views], 1.0
+        )
+        torch.reciprocal(packed_scales, out=packed_scales)
+        multi_tensor_applier(
+            multi_tensor_scale, dummy_overflow_buf, [packed_scale_views, scale_invs], 1.0
+        )
+
+
+def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_model_weights=False):
+    r"""Helper function to cast master weights to FP8 primary weights for current scaling.
+
+    Parameters
+    ----------
+    params : List of tuple, each tuple contains a model weight, a master weight, and an offset
+             indicating the starting index of the master weight in the model weight.
+    group  : The distributed group to do amax reduction. Typically it's the data parallel
+             group.
+    use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
+    """
+
+    # Parameter attributes
+    device = params[0][0].device
+    fp8_dtype = params[0][0]._get_quantizer().dtype
+    force_pow_2_scales = params[0][0]._get_quantizer().force_pow_2_scales
+    amax_epsilon = params[0][0]._get_quantizer().amax_epsilon
+
+    # Create a dummy overflow buffer, it's needed by multi_tensor_applier.
+    dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device)
+
+    # Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce
+    # NCCL kernels at once.
+    packed_amaxes = torch.zeros(len(params), dtype=torch.float32, device=device)
+    amaxes = [packed_amaxes[i : i + 1] for i in range(len(params))]
+
+    # Collect scales and scale_invs to update them after amax reduction.
+    scales, scale_invs = [], []
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 1: Iterate through all the none empty master weights and compute amax of them. Store the
+    #         amaxes in a contiguous buffer. If the master weight is None, the corresponding amax
+    #         will be set to 0.
+    # ---------------------------------------------------------------------------------------------
+    for (model_weight, master_weight, _, _), amax in zip(params, amaxes):
+
+        # Make sure all the model weights have the same numerical options.
+        quantizer = model_weight._get_quantizer()
+        assert quantizer.dtype == fp8_dtype
+        assert quantizer.force_pow_2_scales == force_pow_2_scales
+        assert quantizer.amax_epsilon == amax_epsilon
+
+        scales.append(quantizer.scale.view(1))
+        scale_invs.append(model_weight._scale_inv.view(1))
+
+        # Compute amax of the master weight and store it in packed_amaxes.
+        if master_weight is not None:
+            tex.compute_amax(master_weight, amax)
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 2: Perform all-reduce on packed_amaxes to get the global amax.
+    # ---------------------------------------------------------------------------------------------
+    torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 3: Update scales and scale_invs.
+    # ---------------------------------------------------------------------------------------------
+    if fp8_dtype == tex.DType.kFloat8E4M3:
+        max_fp8 = 448.0
+    elif fp8_dtype == tex.DType.kFloat8E5M2:
+        max_fp8 = 57344.0
+    else:
+        raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
+    multi_tensor_applier(
+        multi_tensor_compute_scale_and_scale_inv,
+        dummy_overflow_buf,
+        [amaxes, scales, scale_invs],
+        max_fp8,
+        force_pow_2_scales,
+        amax_epsilon,
+    )
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 4: Cast master weights to FP8.
+    # ---------------------------------------------------------------------------------------------
+    for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
+        params, scales
+    ):
+        # Reset transpose cache for all model weights.
+        # We cannot create transpose cache here because users (like megatron) may want to overlap
+        # the all-gather of model weights and forward process, so the model weight is not updated
+        # currently.
+        model_weight._reset_caches()
+
+        # If master weight is None, it means that the master weight of the current model weight
+        # is in other DP ranks.
+        if master_weight is None:
+            continue
+
+        # Cast master weight to FP8
+        end_offset = start_offset + master_weight.numel()
+        if not use_fsdp_shard_model_weights:
+            model_weight_fragment = model_weight.reshape(-1)[start_offset:end_offset]
+        quantizer = Float8Quantizer(
+            scale=scale,
+            amax=torch.Tensor(),
+            fp8_dtype=model_weight._fp8_dtype,
+        )
+        if use_fsdp_shard_model_weights and not isinstance(model_weight_fragment, Float8Tensor):
+            # NOTE: The fsdp shard model weight may be a unit8 tensor instead of
+            # a float8 tensor. We should handle this situation properly.
+            model_weight_fragment = quantizer.create_tensor_from_data(
+                model_weight_fragment.view(-1),
+                model_weight.dtype,
+            )
+        quantizer.update_quantized(master_weight, model_weight_fragment)