Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
TransformerEngine
Commits
9df0c4a3
Commit
9df0c4a3
authored
Mar 03, 2026
by
wenjh
Browse files
Merge branch 'nv_main'
parents
0d874a4e
f122b07d
Changes
221
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1691 additions
and
580 deletions
+1691
-580
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
+49
-55
transformer_engine/pytorch/ops/fused/backward_linear_add.py
transformer_engine/pytorch/ops/fused/backward_linear_add.py
+58
-59
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
...sformer_engine/pytorch/ops/fused/backward_linear_scale.py
+54
-55
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
...ngine/pytorch/ops/fused/forward_linear_bias_activation.py
+59
-58
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
...ormer_engine/pytorch/ops/fused/forward_linear_bias_add.py
+55
-64
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
...rmer_engine/pytorch/ops/fused/forward_linear_scale_add.py
+62
-66
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
...r_engine/pytorch/ops/fused/userbuffers_backward_linear.py
+71
-86
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
...er_engine/pytorch/ops/fused/userbuffers_forward_linear.py
+65
-79
transformer_engine/pytorch/ops/fuser.py
transformer_engine/pytorch/ops/fuser.py
+119
-49
transformer_engine/pytorch/optimizers/fused_sgd.py
transformer_engine/pytorch/optimizers/fused_sgd.py
+2
-2
transformer_engine/pytorch/quantized_tensor.py
transformer_engine/pytorch/quantized_tensor.py
+15
-3
transformer_engine/pytorch/tensor/float8_tensor.py
transformer_engine/pytorch/tensor/float8_tensor.py
+17
-2
transformer_engine/pytorch/tensor/mxfp8_tensor.py
transformer_engine/pytorch/tensor/mxfp8_tensor.py
+44
-1
transformer_engine/pytorch/tensor/nvfp4_tensor.py
transformer_engine/pytorch/tensor/nvfp4_tensor.py
+4
-1
transformer_engine/pytorch/tensor/storage/__init__.py
transformer_engine/pytorch/tensor/storage/__init__.py
+1
-0
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
...pytorch/tensor/storage/float8_blockwise_tensor_storage.py
+18
-0
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
...er_engine/pytorch/tensor/storage/float8_tensor_storage.py
+18
-0
transformer_engine/pytorch/tensor/storage/grouped_tensor.py
transformer_engine/pytorch/tensor/storage/grouped_tensor.py
+942
-0
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
...mer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
+18
-0
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
...mer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+20
-0
No files found.
transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py
View file @
9df0c4a3
...
@@ -42,7 +42,7 @@ class BackwardAddRMSNorm(FusedOperation):
...
@@ -42,7 +42,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Get basic operations
# Get basic operations
rmsnorm_op
=
self
.
basic_ops
[
1
]
rmsnorm_op
=
self
.
basic_ops
[
1
]
rmsnorm_op_ctx
=
basic_op_ctxs
[
0
]
rmsnorm_op_ctx
=
basic_op_ctxs
[
1
]
# Saved tensors from forward pass
# Saved tensors from forward pass
x
,
rstdevs
=
rmsnorm_op_ctx
.
saved_tensors
x
,
rstdevs
=
rmsnorm_op_ctx
.
saved_tensors
...
@@ -53,7 +53,7 @@ class BackwardAddRMSNorm(FusedOperation):
...
@@ -53,7 +53,7 @@ class BackwardAddRMSNorm(FusedOperation):
# Check input tensors
# Check input tensors
dtype
=
rmsnorm_op_ctx
.
dtype
dtype
=
rmsnorm_op_ctx
.
dtype
extra_grad
=
basic_op_grad_extra_outputs
[
1
][
0
]
extra_grad
=
basic_op_grad_extra_outputs
[
0
][
0
]
dy
=
maybe_dequantize
(
grad_output
.
contiguous
(),
dtype
).
view
(
x
.
size
())
dy
=
maybe_dequantize
(
grad_output
.
contiguous
(),
dtype
).
view
(
x
.
size
())
w
=
maybe_dequantize
(
rmsnorm_op
.
weight
,
dtype
).
view
((
inner_dim
,))
w
=
maybe_dequantize
(
rmsnorm_op
.
weight
,
dtype
).
view
((
inner_dim
,))
add
=
maybe_dequantize
(
extra_grad
.
contiguous
(),
dtype
).
view
(
x
.
size
())
add
=
maybe_dequantize
(
extra_grad
.
contiguous
(),
dtype
).
view
(
x
.
size
())
...
@@ -77,57 +77,51 @@ class BackwardAddRMSNorm(FusedOperation):
...
@@ -77,57 +77,51 @@ class BackwardAddRMSNorm(FusedOperation):
grad_input
=
dx
.
view
(
grad_output
.
size
())
grad_input
=
dx
.
view
(
grad_output
.
size
())
grad_weight
=
dw
.
view
(
weight_dims
)
grad_weight
=
dw
.
view
(
weight_dims
)
return
grad_input
,
[(
grad_weight
,),
()],
[(),
()]
return
grad_input
,
[(),
(
grad_weight
,)],
[(),
()]
@
staticmethod
def
fuse_backward_add_rmsnorm
(
def
fuse_backward_ops
(
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
ops
:
list
[
FusibleOperation
],
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
**
unused
,
# pylint: disable=unused-argument
"""Fused backward RMNorm + add
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for backward pass.
Parameters
----------
Parameters
ops : list of tuples
----------
Backward pass operations and the indices of the corresponding
ops : list of FusibleOperation
basic operations.
Backward pass operations.
Returns
Returns
-------
-------
ops : list of tuples
ops : list of FusibleOperation
Updated backward pass operations
Updated backward pass operations
"""
"""
# Scan through ops, fusing if possible
# Scan through ops, fusing if possible
out
=
[]
out
=
[]
window
=
[]
window
,
ops
=
ops
[:
2
],
ops
[
2
:]
while
len
(
ops
)
>=
2
:
while
len
(
window
)
==
2
:
if
(
isinstance
(
window
[
0
],
MakeExtraOutput
)
and
isinstance
(
window
[
1
],
RMSNorm
)
and
not
window
[
0
].
_in_place
):
# Construct fused op if window matches pattern
op
=
BackwardAddRMSNorm
(
add
=
window
[
0
],
rmsnorm
=
window
[
1
])
window
=
[
op
]
else
:
# Shift window if window doesn't match pattern
out
.
extend
(
window
[:
-
1
])
window
=
window
[
-
1
:]
# Adjust window to expected size
out
.
extend
(
window
[:
-
2
])
window
=
window
[
-
2
:]
while
ops
and
len
(
window
)
<
2
:
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
window
)
return
out
# Check if first op is linear
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
op
,
_
=
window
[
0
]
if
not
isinstance
(
op
,
RMSNorm
):
continue
# Check if second op is "make extra output"
op
,
_
=
ops
[
0
]
if
not
isinstance
(
op
,
MakeExtraOutput
):
continue
if
op
.
_in_place
:
continue
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
# Replace window with fused op
op
=
BackwardAddRMSNorm
(
rmsnorm
=
window
[
0
][
0
],
add
=
window
[
1
][
0
],
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
window
=
[(
op
,
basic_op_idxs
)]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
ops
)
return
out
transformer_engine/pytorch/ops/fused/backward_linear_add.py
View file @
9df0c4a3
...
@@ -45,7 +45,7 @@ class BackwardLinearAdd(FusedOperation):
...
@@ -45,7 +45,7 @@ class BackwardLinearAdd(FusedOperation):
# Get basic operations
# Get basic operations
linear_op
=
self
.
basic_ops
[
1
]
linear_op
=
self
.
basic_ops
[
1
]
linear_op_ctx
=
basic_op_ctxs
[
0
]
linear_op_ctx
=
basic_op_ctxs
[
1
]
# Saved tensors from forward pass
# Saved tensors from forward pass
(
x_local
,
w
)
=
linear_op_ctx
.
saved_tensors
(
x_local
,
w
)
=
linear_op_ctx
.
saved_tensors
...
@@ -71,7 +71,7 @@ class BackwardLinearAdd(FusedOperation):
...
@@ -71,7 +71,7 @@ class BackwardLinearAdd(FusedOperation):
accumulate_into_main_grad
=
False
accumulate_into_main_grad
=
False
# Linear backward pass
# Linear backward pass
grad_input
=
basic_op_grad_extra_outputs
[
1
][
0
]
grad_input
=
basic_op_grad_extra_outputs
[
0
][
0
]
grad_input
,
grad_weight
=
BasicLinear
.
_functional_backward
(
grad_input
,
grad_weight
=
BasicLinear
.
_functional_backward
(
grad_output
=
grad_output
,
grad_output
=
grad_output
,
input
=
x_local
,
input
=
x_local
,
...
@@ -109,61 +109,60 @@ class BackwardLinearAdd(FusedOperation):
...
@@ -109,61 +109,60 @@ class BackwardLinearAdd(FusedOperation):
zero
=
getattr
(
weight_param
,
"zero_out_wgrad"
,
False
),
zero
=
getattr
(
weight_param
,
"zero_out_wgrad"
,
False
),
)
)
return
grad_input
,
[(
grad_weight
,),
()],
[(),
()]
return
grad_input
,
[(),
(
grad_weight
,)],
[(),
()]
@
staticmethod
def
fuse_backward_linear_add
(
def
fuse_backward_ops
(
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
ops
:
list
[
FusibleOperation
],
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
**
unused
,
# pylint: disable=unused-argument
"""Fused backward dgrad GEMM + add
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for backward pass.
Parameters
----------
Parameters
ops : list of tuples
----------
Backward pass operations and the indices of the corresponding
ops : list of FusibleOperation
basic operations.
Backward pass operations.
Returns
Returns
-------
-------
ops : list of tuples
ops : list of FusibleOperation
Updated backward pass operations
Updated backward pass operations
"""
"""
# Scan through ops, fusing if possible
# Scan through ops, fusing if possible
out
=
[]
out
=
[]
window
=
[]
window
,
ops
=
ops
[:
2
],
ops
[
2
:]
while
len
(
ops
)
>=
2
:
while
len
(
window
)
==
2
:
# Check if window matches pattern
matches_pattern
=
True
if
not
(
isinstance
(
window
[
0
],
MakeExtraOutput
)
and
isinstance
(
window
[
1
],
BasicLinear
)):
matches_pattern
=
False
elif
not
window
[
0
].
_in_place
:
# Fused op accumulates grad input in-place
matches_pattern
=
False
elif
window
[
1
].
tensor_parallel_mode
==
"column"
:
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern
=
False
if
matches_pattern
:
# Construct fused op if window matches pattern
op
=
BackwardLinearAdd
(
backward_add
=
window
[
0
],
linear
=
window
[
1
])
window
=
[
op
]
else
:
# Shift window if window doesn't match pattern
out
.
extend
(
window
[:
-
1
])
window
=
window
[
-
1
:]
# Adjust window to expected size
out
.
extend
(
window
[:
-
2
])
window
=
window
[
-
2
:]
while
ops
and
len
(
window
)
<
2
:
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
window
)
return
out
# Check if first op is linear
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
op
,
_
=
window
[
0
]
if
not
isinstance
(
op
,
BasicLinear
):
continue
if
op
.
tensor_parallel_mode
==
"column"
:
# Row tensor-parallelism requires communication after the
# GEMM
continue
# Check if second op is "make extra output"
op
,
_
=
ops
[
0
]
if
not
isinstance
(
op
,
MakeExtraOutput
):
continue
if
not
op
.
_in_place
:
continue
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
# Replace window with fused op
op
=
BackwardLinearAdd
(
linear
=
window
[
0
][
0
],
backward_add
=
window
[
1
][
0
],
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
window
=
[(
op
,
basic_op_idxs
)]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
ops
)
return
out
transformer_engine/pytorch/ops/fused/backward_linear_scale.py
View file @
9df0c4a3
...
@@ -45,7 +45,7 @@ class BackwardLinearScale(FusedOperation):
...
@@ -45,7 +45,7 @@ class BackwardLinearScale(FusedOperation):
# Get basic operations
# Get basic operations
linear_op
=
self
.
basic_ops
[
0
]
linear_op
=
self
.
basic_ops
[
0
]
linear_op_ctx
=
basic_op_ctxs
[
1
]
linear_op_ctx
=
basic_op_ctxs
[
0
]
scale_op
=
self
.
basic_ops
[
1
]
scale_op
=
self
.
basic_ops
[
1
]
# Saved tensors from forward pass
# Saved tensors from forward pass
...
@@ -109,58 +109,57 @@ class BackwardLinearScale(FusedOperation):
...
@@ -109,58 +109,57 @@ class BackwardLinearScale(FusedOperation):
zero
=
getattr
(
weight_param
,
"zero_out_wgrad"
,
False
),
zero
=
getattr
(
weight_param
,
"zero_out_wgrad"
,
False
),
)
)
return
grad_input
,
[(),
(
grad_weight
,)],
[(),
()]
return
grad_input
,
[(
grad_weight
,),
()],
[(),
()]
@
staticmethod
def
fuse_backward_linear_scale
(
def
fuse_backward_ops
(
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
ops
:
list
[
FusibleOperation
],
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
**
unused
,
# pylint: disable=unused-argument
"""Fused backward dgrad GEMM + constant scale
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for backward pass.
Parameters
----------
Parameters
ops : list of tuples
----------
Backward pass operations and the indices of the corresponding
ops : list of FusibleOperation
basic operations.
Backward pass operations.
Returns
Returns
-------
-------
ops : list of tuples
ops : list of FusibleOperation
Updated backward pass operations
Updated backward pass operations
"""
"""
# Scan through ops, fusing if possible
# Scan through ops, fusing if possible
out
=
[]
out
=
[]
window
=
[]
window
,
ops
=
ops
[:
2
],
ops
[
2
:]
while
len
(
ops
)
>=
2
:
while
len
(
window
)
==
2
:
# Check if window matches pattern
matches_pattern
=
True
if
not
(
isinstance
(
window
[
0
],
BasicLinear
)
and
isinstance
(
window
[
1
],
ConstantScale
)):
matches_pattern
=
False
elif
window
[
0
].
tensor_parallel_mode
==
"column"
:
# Column tensor-parallelism requires communication
# after the dgrad GEMM
matches_pattern
=
False
if
matches_pattern
:
# Construct fused op if window matches pattern
op
=
BackwardLinearScale
(
linear
=
window
[
0
],
scale
=
window
[
1
])
window
=
[
op
]
else
:
# Shift window if window doesn't match pattern
out
.
extend
(
window
[:
-
1
])
window
=
window
[
-
1
:]
# Adjust window to expected size
out
.
extend
(
window
[:
-
2
])
window
=
window
[
-
2
:]
while
ops
and
len
(
window
)
<
2
:
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
window
)
return
out
# Check if first op is constant scale
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
op
,
_
=
window
[
0
]
if
not
isinstance
(
op
,
ConstantScale
):
continue
# Check if second op is linear
op
,
_
=
ops
[
0
]
if
not
isinstance
(
op
,
BasicLinear
):
continue
if
op
.
tensor_parallel_mode
==
"column"
:
# Column tensor-parallelism requires communication after the dgrad GEMM
continue
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
# Replace window with fused op
op
=
BackwardLinearScale
(
scale
=
window
[
0
][
0
],
linear
=
window
[
1
][
0
],
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
window
=
[(
op
,
basic_op_idxs
)]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
ops
)
return
out
transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py
View file @
9df0c4a3
...
@@ -134,62 +134,63 @@ class ForwardLinearBiasActivation(FusedOperation):
...
@@ -134,62 +134,63 @@ class ForwardLinearBiasActivation(FusedOperation):
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
@
staticmethod
def
fuse_forward_linear_bias_activation
(
def
fuse_forward_ops
(
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
ops
:
list
[
FusibleOperation
],
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
**
unused
,
# pylint: disable=unused-argument
"""Fuse forward GEMM + bias + activation
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for forward pass.
Parameters
----------
Parameters
ops : list of tuples
----------
Forward pass operations and the indices of the corresponding
ops : list of FusibleOperation
basic operations.
Forward pass operations.
Returns
Returns
-------
-------
ops : list of tuples
ops : list of FusibleOperation
Updated forward pass operations
Updated forward pass operations
"""
"""
# Scan through ops, fusing if possible
# Scan through ops, fusing if possible
out
=
[]
out
=
[]
window
=
[]
window
,
ops
=
ops
[:
2
],
ops
[
2
:]
while
len
(
ops
)
>=
2
:
while
len
(
window
)
==
2
:
# Check if window matches pattern
matches_pattern
=
True
if
not
(
isinstance
(
window
[
0
],
BasicLinear
)
and
isinstance
(
window
[
1
],
Bias
)):
matches_pattern
=
False
elif
window
[
0
].
tensor_parallel_mode
==
"row"
:
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern
=
False
elif
window
[
0
].
weight
.
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
matches_pattern
=
False
if
matches_pattern
:
# Construct fused op if window matches pattern
op
=
ForwardLinearBiasActivation
(
linear
=
window
[
0
],
bias
=
window
[
1
],
activation
=
None
,
)
window
=
[
op
]
else
:
# Shift window if window doesn't match pattern
out
.
extend
(
window
[:
-
1
])
window
=
window
[
-
1
:]
# Adjust window to expected size
out
.
extend
(
window
[:
-
2
])
window
=
window
[
-
2
:]
while
ops
and
len
(
window
)
<
2
:
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
window
)
return
out
# Check if first op is linear
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
op1
,
_
=
window
[
0
]
if
not
isinstance
(
op1
,
BasicLinear
):
continue
if
op1
.
tensor_parallel_mode
==
"row"
:
# Row tensor-parallelism requires communication after the
# GEMM
continue
if
op1
.
weight
.
dtype
not
in
(
torch
.
float16
,
torch
.
bfloat16
):
# cuBLAS only supports fused GEMM+bias+activation with
# FP16 and BF16 output
continue
# Check if second op is bias
op2
,
_
=
ops
[
0
]
if
not
isinstance
(
op2
,
Bias
):
continue
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
# Replace window with fused op
op
=
ForwardLinearBiasActivation
(
linear
=
window
[
0
][
0
],
bias
=
window
[
1
][
0
],
activation
=
None
,
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
window
=
[(
op
,
basic_op_idxs
)]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
ops
)
return
out
transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py
View file @
9df0c4a3
...
@@ -131,72 +131,63 @@ class ForwardLinearBiasAdd(FusedOperation):
...
@@ -131,72 +131,63 @@ class ForwardLinearBiasAdd(FusedOperation):
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
@
staticmethod
def
fuse_forward_ops
(
ops
:
list
[
FusibleOperation
],
**
unused
,
# pylint: disable=unused-argument
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for forward pass.
Parameters
----------
ops : list of FusibleOperation
Forward pass operations.
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
"""
# Scan through ops, fusing if possible
out
=
[]
window
=
[]
while
ops
:
# Shift window
out
.
extend
(
window
)
window
=
[
ops
[
0
]]
ops
=
ops
[
1
:]
def
fuse_forward_linear_bias_add
(
# Check if first op is linear
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
if
not
isinstance
(
window
[
0
],
BasicLinear
):
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
continue
"""Fuse forward GEMM + bias + add
if
window
[
0
].
tensor_parallel_mode
==
"row"
:
# Row tensor-parallelism requires communication after
Parameters
# the GEMM
----------
continue
ops : list of tuples
linear
=
window
[
0
]
Forward pass operations and the indices of the corresponding
basic operations.
Returns
# Check if next op is bias
-------
bias
=
None
ops : list of tuples
if
ops
and
isinstance
(
ops
[
0
],
Bias
):
Updated forward pass operations
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
bias
=
window
[
-
1
]
# Check if next op is in-place add extra input
if
ops
and
isinstance
(
ops
[
0
],
AddExtraInput
)
and
ops
[
0
].
_in_place
:
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
add
=
window
[
-
1
]
else
:
continue
"""
# Replace window with fused op
op
=
ForwardLinearBiasAdd
(
linear
=
linear
,
bias
=
bias
,
add
=
add
)
window
=
[
op
]
# Scan through ops, fusing if possible
# Return list of ops
out
=
[]
window
=
[]
while
len
(
ops
)
>=
2
:
out
.
extend
(
window
)
out
.
extend
(
window
)
return
out
# Check if first op is linear
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
op
,
_
=
window
[
0
]
if
not
isinstance
(
op
,
BasicLinear
):
continue
if
op
.
tensor_parallel_mode
==
"row"
:
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear
=
op
op
,
_
=
ops
[
0
]
# Check if next op is bias
bias
=
None
if
isinstance
(
op
,
Bias
):
bias
=
op
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
if
len
(
ops
)
==
0
:
continue
op
,
_
=
ops
[
0
]
# Check if next op is in-place add extra input
if
not
isinstance
(
op
,
AddExtraInput
):
continue
if
not
op
.
_in_place
:
continue
add
=
op
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
# Replace window with fused op
op
=
ForwardLinearBiasAdd
(
linear
=
linear
,
bias
=
bias
,
add
=
add
,
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
window
=
[(
op
,
basic_op_idxs
)]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
ops
)
return
out
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py
View file @
9df0c4a3
...
@@ -110,70 +110,66 @@ class ForwardLinearScaleAdd(FusedOperation):
...
@@ -110,70 +110,66 @@ class ForwardLinearScaleAdd(FusedOperation):
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
@
staticmethod
def
fuse_forward_linear_scale_add
(
def
fuse_forward_ops
(
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
ops
:
list
[
FusibleOperation
],
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
**
unused
,
# pylint: disable=unused-argument
"""Fuse forward GEMM + scale + add
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for forward pass.
Parameters
----------
Parameters
ops : list of tuples
----------
Forward pass operations and the indices of the corresponding
ops : list of FusibleOperation
basic operations.
Forward pass operations.
Returns
Returns
-------
-------
ops : list of tuples
ops : list of FusibleOperation
Updated forward pass operations
Updated forward pass operations
"""
"""
# Scan through ops, fusing if possible
# Scan through ops, fusing if possible
out
=
[]
out
=
[]
window
=
[]
window
,
ops
=
ops
[:
3
],
ops
[
3
:]
while
len
(
ops
)
>=
3
:
while
len
(
window
)
==
3
:
# Check if window matches pattern
matches_pattern
=
True
if
not
(
isinstance
(
window
[
0
],
BasicLinear
)
and
isinstance
(
window
[
1
],
ConstantScale
)
and
isinstance
(
window
[
2
],
AddExtraInput
)
):
matches_pattern
=
False
elif
window
[
0
].
tensor_parallel_mode
==
"row"
:
# Row tensor-parallelism requires communication after
# the GEMM
matches_pattern
=
False
elif
not
window
[
2
].
_in_place
:
# Fused op accumulates output in-place
matches_pattern
=
False
if
matches_pattern
:
# Construct fused op if window matches pattern
op
=
ForwardLinearScaleAdd
(
linear
=
window
[
0
],
scale
=
window
[
1
],
add
=
window
[
2
],
)
window
=
[
op
]
else
:
# Shift window if window doesn't match pattern
out
.
extend
(
window
[:
-
2
])
window
=
window
[
-
2
:]
# Adjust window to expected size
out
.
extend
(
window
[:
-
3
])
window
=
window
[
-
3
:]
while
ops
and
len
(
window
)
<
3
:
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
window
)
return
out
# Check if first op is linear
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
op
,
_
=
window
[
0
]
if
not
isinstance
(
op
,
BasicLinear
):
continue
if
op
.
tensor_parallel_mode
==
"row"
:
# Row tensor-parallelism requires communication after the
# GEMM
continue
linear
=
op
op
,
_
=
ops
[
0
]
# Check if next op is constant scale
if
not
isinstance
(
op
,
ConstantScale
):
continue
scale
=
op
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
op
,
_
=
ops
[
0
]
# Check if next op is in-place add extra input
if
not
isinstance
(
op
,
AddExtraInput
):
continue
if
not
op
.
_in_place
:
continue
add
=
op
window
.
extend
(
ops
[:
1
])
ops
=
ops
[
1
:]
# Replace window with fused op
op
=
ForwardLinearScaleAdd
(
linear
=
linear
,
scale
=
scale
,
add
=
add
,
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
window
=
[(
op
,
basic_op_idxs
)]
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
ops
)
return
out
transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py
View file @
9df0c4a3
...
@@ -503,7 +503,7 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -503,7 +503,7 @@ class UserbuffersBackwardLinear(FusedOperation):
# Get basic operations
# Get basic operations
idx
=
self
.
_op_idxs
[
"linear"
]
idx
=
self
.
_op_idxs
[
"linear"
]
linear_op
=
self
.
basic_ops
[
idx
]
linear_op
=
self
.
basic_ops
[
idx
]
linear_op_ctx
=
basic_op_ctxs
[
-
1
]
linear_op_ctx
=
basic_op_ctxs
[
0
]
bias_op
=
None
bias_op
=
None
if
self
.
_op_idxs
[
"bias"
]
is
not
None
:
if
self
.
_op_idxs
[
"bias"
]
is
not
None
:
idx
=
self
.
_op_idxs
[
"bias"
]
idx
=
self
.
_op_idxs
[
"bias"
]
...
@@ -578,99 +578,84 @@ class UserbuffersBackwardLinear(FusedOperation):
...
@@ -578,99 +578,84 @@ class UserbuffersBackwardLinear(FusedOperation):
grad_params
[
self
.
_op_idxs
[
"linear"
]]
=
(
grad_weight
,)
grad_params
[
self
.
_op_idxs
[
"linear"
]]
=
(
grad_weight
,)
if
bias_op
is
not
None
:
if
bias_op
is
not
None
:
grad_params
[
self
.
_op_idxs
[
"bias"
]]
=
(
grad_bias
,)
grad_params
[
self
.
_op_idxs
[
"bias"
]]
=
(
grad_bias
,)
grad_params
.
reverse
()
grad_extra_inputs
=
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
grad_extra_inputs
=
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
return
grad_input
,
grad_params
,
grad_extra_inputs
return
grad_input
,
grad_params
,
grad_extra_inputs
@
staticmethod
def
fuse_backward_ops
(
ops
:
list
[
FusibleOperation
],
**
unused
,
# pylint: disable=unused-argument
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for backward pass.
def
fuse_userbuffers_backward_linear
(
Parameters
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
----------
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
ops : list of FusibleOperation
"""Substitute linear operations with Userbuffers implementation
Backward pass operations.
recipe : Recipe, optional
Quantization recipe.
Parameters
Returns
----------
-------
ops : list of tuples
ops : list of FusibleOperation
Backward pass operations and the indices of the corresponding
Updated backward pass operations
basic operations.
Returns
"""
-------
ops : list of tuples
Updated backward pass operations
"""
# Return immediately if environment is not distributed
if
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
()
==
1
:
return
ops
# Return immediately if environment is not distributed
# Scan through ops, fusing if possible
if
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
()
==
1
:
out
=
[]
return
ops
window
=
[]
while
ops
:
# Sliding window in list of ops
window
=
[]
# Shift window
out
.
extend
(
window
)
def
peek_next_op
()
->
Optional
[
FusibleOperation
]:
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
"""Get next op in list of ops"""
nonlocal
ops
# Check if first op is linear
if
not
ops
:
if
not
isinstance
(
window
[
0
],
BasicLinear
):
return
None
return
ops
[
-
1
][
0
]
def
pop_next_op
()
->
FusibleOperation
:
"""Remove next op from list of ops and add to sliding window"""
nonlocal
ops
,
window
window
.
insert
(
0
,
ops
[
-
1
])
ops
=
ops
[:
-
1
]
return
window
[
0
][
0
]
# Scan through ops in reverse order, fusing if possible
out_reversed
=
[]
while
ops
:
out_reversed
.
extend
(
reversed
(
window
))
window
.
clear
()
# Check if next op is linear
next_op
=
pop_next_op
()
if
not
isinstance
(
next_op
,
BasicLinear
):
continue
linear
=
next_op
if
linear
.
_userbuffers_options
is
None
:
continue
# Check if next op is bias
bias
=
None
if
linear
.
tensor_parallel_mode
!=
"row"
and
isinstance
(
peek_next_op
(),
Bias
):
bias
=
pop_next_op
()
# Check if next op is reduce-scatter
reduce_scatter
=
None
if
linear
.
tensor_parallel_mode
is
None
and
isinstance
(
peek_next_op
(),
ReduceScatter
):
reduce_scatter
=
pop_next_op
()
# Check for invalid combinations
if
reduce_scatter
is
None
:
if
linear
.
tensor_parallel_mode
is
None
:
continue
if
linear
.
tensor_parallel_size
==
1
:
continue
if
linear
.
tensor_parallel_mode
==
"row"
and
bias
is
not
None
:
continue
else
:
if
linear
.
tensor_parallel_mode
is
not
None
:
continue
continue
if
reduce_scatter
.
process_group_size
==
1
:
linear
=
window
[
0
]
if
linear
.
_userbuffers_options
is
None
:
continue
continue
# Replace window with fused op
# Check if next op is bias
op
=
UserbuffersBackwardLinear
(
bias
=
None
linear
=
linear
,
if
linear
.
tensor_parallel_mode
!=
"row"
and
ops
and
isinstance
(
ops
[
0
],
Bias
):
bias
=
bias
,
bias
,
ops
=
ops
[
0
],
ops
[
1
:]
reduce_scatter
=
reduce_scatter
,
window
.
append
(
bias
)
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
# Check if next op is reduce-scatter
window
=
[(
op
,
basic_op_idxs
)]
reduce_scatter
=
None
if
linear
.
tensor_parallel_mode
is
None
and
ops
and
isinstance
(
ops
[
0
],
ReduceScatter
):
# Return list of ops
reduce_scatter
,
ops
=
ops
[
0
],
ops
[
1
:]
out_reversed
.
extend
(
reversed
(
window
))
window
.
append
(
reduce_scatter
)
out
=
out_reversed
out
.
reverse
()
# Check for invalid combinations
return
out
if
reduce_scatter
is
None
:
if
linear
.
tensor_parallel_mode
is
None
:
continue
if
linear
.
tensor_parallel_size
==
1
:
continue
if
linear
.
tensor_parallel_mode
==
"row"
and
bias
is
not
None
:
continue
else
:
if
linear
.
tensor_parallel_mode
is
not
None
:
continue
if
reduce_scatter
.
process_group_size
==
1
:
continue
# Replace window with fused op
op
=
UserbuffersBackwardLinear
(
linear
=
linear
,
bias
=
bias
,
reduce_scatter
=
reduce_scatter
,
)
window
=
[
op
]
# Return list of ops
out
.
extend
(
window
)
return
out
transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py
View file @
9df0c4a3
...
@@ -369,93 +369,79 @@ class UserbuffersForwardLinear(FusedOperation):
...
@@ -369,93 +369,79 @@ class UserbuffersForwardLinear(FusedOperation):
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
return
output
,
[()
for
_
in
range
(
len
(
self
.
basic_ops
))]
@
staticmethod
def
fuse_forward_ops
(
ops
:
list
[
FusibleOperation
],
**
unused
,
# pylint: disable=unused-argument
)
->
list
[
FusibleOperation
]:
"""Apply operation fusion for forward pass.
def
fuse_userbuffers_forward_linear
(
Parameters
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
----------
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
ops : list of FusibleOperation
"""Substitute linear operations with Userbuffers implementation
Forward pass operations.
Parameters
----------
ops : list of tuples
Forward pass operations and the indices of the corresponding
basic operations.
Returns
-------
ops : list of tuples
Updated forward pass operations
"""
Returns
-------
ops : list of FusibleOperation
Updated forward pass operations
# Return immediately if environment is not distributed
"""
if
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
()
==
1
:
return
ops
# Sliding window in list of ops
window
=
[]
def
peek_next_op
()
->
Optional
[
FusibleOperation
]:
"""Get next op in list of ops"""
nonlocal
ops
if
not
ops
:
return
None
return
ops
[
0
][
0
]
def
pop_next_op
()
->
FusibleOperation
:
"""Remove next op from list of ops and add to sliding window"""
nonlocal
ops
,
window
window
.
append
(
ops
[
0
])
ops
=
ops
[
1
:]
return
window
[
-
1
][
0
]
# Scan through ops, fusing if possible
out
=
[]
while
ops
:
out
.
extend
(
window
)
window
.
clear
()
# Check if next op is linear
# Return immediately if environment is not distributed
next_op
=
pop_next_op
()
if
not
torch
.
distributed
.
is_initialized
()
or
torch
.
distributed
.
get_world_size
()
==
1
:
if
not
isinstance
(
next_op
,
BasicLinear
):
return
ops
continue
linear
=
next_op
if
linear
.
_userbuffers_options
is
None
:
continue
#
Check if next op is bias
#
Scan through ops, fusing if possible
bias
=
None
out
=
[]
if
linear
.
tensor_parallel_mode
!=
"row"
and
isinstance
(
peek_next_op
(),
Bias
):
window
=
[]
bias
=
pop_next_op
()
while
ops
:
# Check if next op is reduce-scatter
# Shift window
reduce_scatter
=
None
out
.
extend
(
window
)
if
linear
.
tensor_parallel_mode
is
None
and
isinstance
(
peek_next_op
(),
ReduceScatter
):
window
,
ops
=
ops
[:
1
],
ops
[
1
:]
reduce_scatter
=
pop_next_op
()
# Check for invalid combinations
# Check if first op is linear
if
reduce_scatter
is
None
:
if
not
isinstance
(
window
[
0
],
BasicLinear
):
if
linear
.
tensor_parallel_mode
is
None
:
continue
if
linear
.
tensor_parallel_size
==
1
:
continue
if
linear
.
tensor_parallel_mode
==
"row"
and
bias
is
not
None
:
continue
else
:
if
linear
.
tensor_parallel_mode
is
not
None
:
continue
continue
if
reduce_scatter
.
process_group_size
==
1
:
linear
=
window
[
0
]
if
linear
.
_userbuffers_options
is
None
:
continue
continue
# Replace window with fused op
# Check if next op is bias
op
=
UserbuffersForwardLinear
(
bias
=
None
linear
=
linear
,
if
linear
.
tensor_parallel_mode
!=
"row"
and
ops
and
isinstance
(
ops
[
0
],
Bias
):
bias
=
bias
,
bias
,
ops
=
ops
[
0
],
ops
[
1
:]
reduce_scatter
=
reduce_scatter
,
window
.
append
(
bias
)
)
basic_op_idxs
=
[
basic_op_idxs
[
0
]
for
_
,
basic_op_idxs
in
window
]
# Check if next op is reduce-scatter
window
=
[(
op
,
basic_op_idxs
)]
reduce_scatter
=
None
if
linear
.
tensor_parallel_mode
is
None
and
ops
and
isinstance
(
ops
[
0
],
ReduceScatter
):
reduce_scatter
,
ops
=
ops
[
0
],
ops
[
1
:]
window
.
append
(
reduce_scatter
)
# Check for invalid combinations
if
reduce_scatter
is
None
:
if
linear
.
tensor_parallel_mode
is
None
:
continue
if
linear
.
tensor_parallel_size
==
1
:
continue
if
linear
.
tensor_parallel_mode
==
"row"
and
bias
is
not
None
:
continue
else
:
if
linear
.
tensor_parallel_mode
is
not
None
:
continue
if
reduce_scatter
.
process_group_size
==
1
:
continue
# Replace window with fused op
op
=
UserbuffersForwardLinear
(
linear
=
linear
,
bias
=
bias
,
reduce_scatter
=
reduce_scatter
,
)
window
=
[
op
]
# Return list of ops
# Return list of ops
out
.
extend
(
window
)
out
.
extend
(
window
)
return
out
return
out
transformer_engine/pytorch/ops/fuser.py
View file @
9df0c4a3
...
@@ -5,33 +5,20 @@
...
@@ -5,33 +5,20 @@
"""Manager class for a pipeline of fusible operations."""
"""Manager class for a pipeline of fusible operations."""
from
__future__
import
annotations
from
__future__
import
annotations
from
collections.abc
import
Callable
,
Iterable
from
collections.abc
import
Callable
,
Iterable
,
Sequence
from
typing
import
Any
,
Optional
import
itertools
import
itertools
from
typing
import
Any
,
Optional
,
TypeAlias
import
torch
import
torch
from
transformer_engine.pytorch.quantization
import
FP8GlobalStateManager
,
Recipe
,
DelayedScaling
from
..quantization
import
FP8GlobalStateManager
,
Recipe
,
DelayedScaling
from
transformer_engine.pytorch.ops.op
import
(
from
..quantized_tensor
import
prepare_for_saving
,
restore_from_saved
from
.op
import
(
BasicOperation
,
BasicOperation
,
FusibleOperation
,
FusibleOperation
,
FusedOperation
,
OperationContext
,
OperationContext
,
)
)
from
transformer_engine.pytorch.ops.fused
import
(
fuse_backward_activation_bias
,
fuse_backward_add_rmsnorm
,
fuse_backward_linear_add
,
fuse_backward_linear_scale
,
fuse_forward_linear_bias_activation
,
fuse_forward_linear_bias_add
,
fuse_forward_linear_scale_add
,
fuse_userbuffers_backward_linear
,
fuse_userbuffers_forward_linear
,
)
from
transformer_engine.pytorch.quantized_tensor
import
(
prepare_for_saving
,
restore_from_saved
,
)
def
_split_tuple
(
t
:
tuple
,
idx
:
int
)
->
tuple
[
tuple
,
tuple
]:
def
_split_tuple
(
t
:
tuple
,
idx
:
int
)
->
tuple
[
tuple
,
tuple
]:
...
@@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool:
...
@@ -57,6 +44,12 @@ def _is_graph_capturing() -> bool:
return
_is_graph_capturing_function
()
return
_is_graph_capturing_function
()
# Type alias for a function that may perform operation fusion
OperationFusionFunction
:
TypeAlias
=
(
"Callable[tuple[list[FusibleOperation], ...], list[FusibleOperation]]"
)
class
_OperationFuserAutogradFunction
(
torch
.
autograd
.
Function
):
class
_OperationFuserAutogradFunction
(
torch
.
autograd
.
Function
):
"""Autograd function for a pipeline of operations
"""Autograd function for a pipeline of operations
...
@@ -241,7 +234,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
...
@@ -241,7 +234,7 @@ class _OperationFuserAutogradFunction(torch.autograd.Function):
dx
=
grad_output
dx
=
grad_output
grad_params
=
[
None
for
_
in
range
(
len
(
basic_ops
))]
grad_params
=
[
None
for
_
in
range
(
len
(
basic_ops
))]
grad_extra_inputs
=
[
None
for
_
in
range
(
len
(
basic_ops
))]
grad_extra_inputs
=
[
None
for
_
in
range
(
len
(
basic_ops
))]
for
op
,
basic_op_idxs
in
backward_ops
:
for
op
,
basic_op_idxs
in
reversed
(
backward_ops
)
:
# Stop if no more gradients are required
# Stop if no more gradients are required
if
all
(
not
basic_op_ctxs
[
idx
].
requires_grad
for
idx
in
basic_op_idxs
):
if
all
(
not
basic_op_ctxs
[
idx
].
requires_grad
for
idx
in
basic_op_idxs
):
...
@@ -315,6 +308,10 @@ class OperationFuser:
...
@@ -315,6 +308,10 @@ class OperationFuser:
"""
"""
# Functions to perform operation fusion
forward_fusion_functions
:
list
[
OperationFusionFunction
]
=
[]
backward_fusion_functions
:
list
[
OperationFusionFunction
]
=
[]
def
__init__
(
def
__init__
(
self
,
self
,
ops
:
list
[
FusibleOperation
],
ops
:
list
[
FusibleOperation
],
...
@@ -334,7 +331,7 @@ class OperationFuser:
...
@@ -334,7 +331,7 @@ class OperationFuser:
self
.
_basic_op_num_extra_inputs
:
list
[
int
]
=
list
(
op
.
num_extra_inputs
for
op
in
basic_ops
)
self
.
_basic_op_num_extra_inputs
:
list
[
int
]
=
list
(
op
.
num_extra_inputs
for
op
in
basic_ops
)
self
.
num_extra_inputs
:
int
=
sum
(
self
.
_basic_op_num_extra_inputs
)
self
.
num_extra_inputs
:
int
=
sum
(
self
.
_basic_op_num_extra_inputs
)
# Ops for forward and backward pass, will be populated in fuse_ops
# Ops for forward and backward pass, will be populated in
maybe_
fuse_ops
self
.
_forward_ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]
self
.
_forward_ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]
self
.
_backward_ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]
self
.
_backward_ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]
...
@@ -349,31 +346,48 @@ class OperationFuser:
...
@@ -349,31 +346,48 @@ class OperationFuser:
self
.
_flat_basic_op_params
=
sum
(
self
.
_basic_op_params
,
[])
self
.
_flat_basic_op_params
=
sum
(
self
.
_basic_op_params
,
[])
@
classmethod
@
classmethod
def
_fuse_forward_ops
(
def
_fuse_ops
(
cls
,
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
recipe
:
Optional
[
Recipe
],
# pylint: disable=unused-argument
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
"""Attempt to fuse operations in forward pass"""
ops
=
fuse_userbuffers_forward_linear
(
ops
)
ops
=
fuse_forward_linear_bias_add
(
ops
)
ops
=
fuse_forward_linear_bias_activation
(
ops
)
ops
=
fuse_forward_linear_scale_add
(
ops
)
return
ops
@
classmethod
def
_fuse_backward_ops
(
cls
,
cls
,
ops
:
list
[
tuple
[
FusibleOperation
,
list
[
int
]]],
basic_ops
:
Sequence
[
BasicOperation
],
fusion_funcs
:
Iterable
[
OperationFusionFunction
],
recipe
:
Optional
[
Recipe
],
recipe
:
Optional
[
Recipe
],
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
)
->
list
[
tuple
[
FusibleOperation
,
list
[
int
]]]:
"""Attempt to fuse operations in backward pass"""
"""Apply operation fusions"""
ops
=
fuse_userbuffers_backward_linear
(
ops
)
ops
=
fuse_backward_linear_add
(
ops
)
# Apply op fusions
ops
=
fuse_backward_linear_scale
(
ops
)
fused_ops
=
list
(
basic_ops
)
ops
=
fuse_backward_activation_bias
(
ops
,
recipe
)
for
func
in
fusion_funcs
:
ops
=
fuse_backward_add_rmsnorm
(
ops
)
fused_ops
=
func
(
fused_ops
,
recipe
=
recipe
)
return
ops
def
raise_mismatch_error
()
->
None
:
"""Throw error indicating invalid op fusion"""
raise
RuntimeError
(
"Found mismatch after fusing operations "
f
"(basic_ops=
{
[
o
.
__class__
.
__name__
for
o
in
basic_ops
]
}
, "
f
"fused_ops=
{
[
o
.
__class__
.
__name__
for
o
in
fused_ops
]
}
)"
)
# Determine basic op indices corresponding to each op
out
=
[]
idx
=
0
for
op
in
fused_ops
:
if
isinstance
(
op
,
FusedOperation
):
idxs
=
[]
for
basic_op
in
op
.
basic_ops
:
if
basic_op
is
not
basic_ops
[
idx
]:
raise_mismatch_error
()
idxs
.
append
(
idx
)
idx
+=
1
out
.
append
((
op
,
idxs
))
else
:
if
op
is
not
basic_ops
[
idx
]:
raise_mismatch_error
()
out
.
append
((
op
,
[
idx
]))
idx
+=
1
if
idx
!=
len
(
basic_ops
):
raise_mismatch_error
()
return
out
def
maybe_fuse_ops
(
def
maybe_fuse_ops
(
self
,
self
,
...
@@ -424,12 +438,16 @@ class OperationFuser:
...
@@ -424,12 +438,16 @@ class OperationFuser:
op
.
pre_first_fuser_forward
()
op
.
pre_first_fuser_forward
()
# Prepare basic op lists for fusions
# Prepare basic op lists for fusions
forward_ops
=
[(
op
,
[
idx
])
for
idx
,
op
in
enumerate
(
self
.
_basic_ops
)]
self
.
_forward_ops
=
OperationFuser
.
_fuse_ops
(
backward_ops
=
list
(
reversed
(
forward_ops
[
first_op_requiring_backward
:]))
self
.
_basic_ops
,
OperationFuser
.
forward_fusion_functions
,
# Fuse ops
recipe
=
recipe
,
self
.
_forward_ops
=
self
.
_fuse_forward_ops
(
forward_ops
,
recipe
)
)
self
.
_backward_ops
=
self
.
_fuse_backward_ops
(
backward_ops
,
recipe
)
self
.
_backward_ops
=
OperationFuser
.
_fuse_ops
(
self
.
_basic_ops
,
OperationFuser
.
backward_fusion_functions
,
recipe
=
recipe
,
)
# Save current fusion params
# Save current fusion params
self
.
recipe_type
,
self
.
first_op_requiring_backward
=
fusion_params
self
.
recipe_type
,
self
.
first_op_requiring_backward
=
fusion_params
...
@@ -491,3 +509,55 @@ class OperationFuser:
...
@@ -491,3 +509,55 @@ class OperationFuser:
*
extra_inputs
,
*
extra_inputs
,
)
)
return
forward_func
(
*
args
)
return
forward_func
(
*
args
)
def register_forward_fusion(
    op_fusion_func: OperationFusionFunction,
    prepend: bool = False,
) -> None:
    """Register function to perform operation fusion for forward pass.

    The fusion function should have the following signature:
    func(ops, *, recipe) -> updated ops

    Parameters
    ----------
    op_fusion_func: function
        Function that takes a list of operations and may substitute
        them with fused operations.
    prepend: bool, default = ``False``
        Whether the operation fuser should apply this fusion function
        first. The default is to apply it last.
    """
    # The fuser applies registered functions in list order, so position matters.
    registry = OperationFuser.forward_fusion_functions
    if prepend:
        registry.insert(0, op_fusion_func)
    else:
        registry.append(op_fusion_func)
def register_backward_fusion(
    op_fusion_func: OperationFusionFunction,
    prepend: bool = False,
) -> None:
    """Register function to perform operation fusion for backward pass.

    The fusion function should have the following signature:
    func(ops, *, recipe) -> updated ops

    Parameters
    ----------
    op_fusion_func: function
        Function that takes a list of operations and may substitute
        them with fused operations.
    prepend: bool, default = ``False``
        Whether the operation fuser should apply this fusion function
        first. The default is to apply it last.
    """
    # The fuser applies registered functions in list order, so position matters.
    registry = OperationFuser.backward_fusion_functions
    if prepend:
        registry.insert(0, op_fusion_func)
    else:
        registry.append(op_fusion_func)
transformer_engine/pytorch/optimizers/fused_sgd.py
View file @
9df0c4a3
...
@@ -123,7 +123,7 @@ class FusedSGD(Optimizer):
...
@@ -123,7 +123,7 @@ class FusedSGD(Optimizer):
self
.
set_grad_none
=
set_grad_none
self
.
set_grad_none
=
set_grad_none
if
self
.
set_grad_none
is
not
None
:
if
self
.
set_grad_none
is
not
None
:
warnings
.
warn
(
warnings
.
warn
(
"set_grad_none kwarg in Fused
Adam
constructor is deprecated. "
"set_grad_none kwarg in Fused
SGD
constructor is deprecated. "
"Use set_to_none kwarg in zero_grad instead."
,
"Use set_to_none kwarg in zero_grad instead."
,
DeprecationWarning
,
DeprecationWarning
,
)
)
...
@@ -147,7 +147,7 @@ class FusedSGD(Optimizer):
...
@@ -147,7 +147,7 @@ class FusedSGD(Optimizer):
if
set_to_none
is
not
None
and
set_to_none
!=
self
.
set_grad_none
:
if
set_to_none
is
not
None
and
set_to_none
!=
self
.
set_grad_none
:
raise
ValueError
(
raise
ValueError
(
f
"Called zero_grad with set_to_none=
{
set_to_none
}
, "
f
"Called zero_grad with set_to_none=
{
set_to_none
}
, "
f
"but Fused
Adam
was initialized with set_grad_none=
{
self
.
set_grad_none
}
"
f
"but Fused
SGD
was initialized with set_grad_none=
{
self
.
set_grad_none
}
"
)
)
set_to_none
=
self
.
set_grad_none
set_to_none
=
self
.
set_grad_none
if
set_to_none
is
None
:
if
set_to_none
is
None
:
...
...
transformer_engine/pytorch/quantized_tensor.py
View file @
9df0c4a3
...
@@ -69,7 +69,9 @@ class QuantizedTensorStorage:
...
@@ -69,7 +69,9 @@ class QuantizedTensorStorage:
f
"
{
self
.
__class__
.
__name__
}
class does not implement get_usages function"
f
"
{
self
.
__class__
.
__name__
}
class does not implement get_usages function"
)
)
def
prepare_for_saving
(
self
)
->
Tuple
[
list
[
Optional
[
torch
.
Tensor
]],
QuantizedTensorStorage
]:
def
prepare_for_saving
(
self
,
)
->
Tuple
[
list
[
Optional
[
torch
.
Tensor
]],
QuantizedTensorStorage
]:
"""Prepare the tensor base for saving for backward"""
"""Prepare the tensor base for saving for backward"""
raise
NotImplementedError
(
raise
NotImplementedError
(
f
"
{
self
.
__class__
.
__name__
}
class does not implement prepare_for_saving function"
f
"
{
self
.
__class__
.
__name__
}
class does not implement prepare_for_saving function"
...
@@ -115,11 +117,18 @@ class QuantizedTensorStorage:
...
@@ -115,11 +117,18 @@ class QuantizedTensorStorage:
warnings
.
warn
(
"Quantizer is being updated, this may affect model behavior"
)
warnings
.
warn
(
"Quantizer is being updated, this may affect model behavior"
)
self
.
_quantizer
=
quantizer
self
.
_quantizer
=
quantizer
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
    """Copy data from another QuantizedTensorStorage."""
    # Abstract hook: concrete storage subclasses are expected to override
    # this with an actual buffer copy.
    message = f"{self.__class__.__name__} class does not implement copy_from_storage function"
    raise NotImplementedError(message)
def
prepare_for_saving
(
def
prepare_for_saving
(
*
tensors
:
Union
[
torch
.
Tensor
,
QuantizedTensorStorage
],
*
tensors
:
Union
[
torch
.
Tensor
,
QuantizedTensorStorage
],
)
->
Tuple
[
)
->
Tuple
[
list
[
Optional
[
Union
[
torch
.
Tensor
,
torch
.
nn
.
Parameter
]]],
list
[
Optional
[
QuantizedTensorStorage
]]
list
[
Optional
[
Union
[
torch
.
Tensor
,
torch
.
nn
.
Parameter
]]],
list
[
Optional
[
QuantizedTensorStorage
]],
]:
]:
"""Prepare tensors for saving. Needed because save_for_backward accepts only
"""Prepare tensors for saving. Needed because save_for_backward accepts only
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
torch.Tensor/torch.nn.Parameter types, while we want to be able to save
...
@@ -144,7 +153,10 @@ def restore_from_saved(
...
@@ -144,7 +153,10 @@ def restore_from_saved(
return_saved_tensors
:
bool
=
False
,
return_saved_tensors
:
bool
=
False
,
)
->
(
)
->
(
list
[
Optional
[
torch
.
Tensor
|
QuantizedTensorStorage
]]
list
[
Optional
[
torch
.
Tensor
|
QuantizedTensorStorage
]]
|
tuple
[
list
[
Optional
[
torch
.
Tensor
|
QuantizedTensorStorage
]],
list
[
Optional
[
torch
.
Tensor
]]]
|
tuple
[
list
[
Optional
[
torch
.
Tensor
|
QuantizedTensorStorage
]],
list
[
Optional
[
torch
.
Tensor
]],
]
):
):
"""Recombine the tensor data and metadata during backward pass."""
"""Recombine the tensor data and metadata during backward pass."""
tensor_objects
=
[]
tensor_objects
=
[]
...
...
transformer_engine/pytorch/tensor/float8_tensor.py
View file @
9df0c4a3
...
@@ -11,7 +11,11 @@ from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
...
@@ -11,7 +11,11 @@ from torch.distributed.fsdp._fully_shard._fsdp_common import TrainingState
import
transformer_engine_torch
as
tex
import
transformer_engine_torch
as
tex
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine_torch
import
DType
as
TE_DType
from
transformer_engine.common.recipe
import
DelayedScaling
,
Float8CurrentScaling
,
Recipe
from
transformer_engine.common.recipe
import
(
DelayedScaling
,
Float8CurrentScaling
,
Recipe
,
)
from
..utils
import
canonicalize_process_group
,
devices_match
from
..utils
import
canonicalize_process_group
,
devices_match
from
.storage.float8_tensor_storage
import
Float8TensorStorage
,
_FromFloat8Func
from
.storage.float8_tensor_storage
import
Float8TensorStorage
,
_FromFloat8Func
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
from
..quantized_tensor
import
QuantizedTensor
,
Quantizer
...
@@ -155,6 +159,10 @@ class Float8Quantizer(Quantizer):
...
@@ -155,6 +159,10 @@ class Float8Quantizer(Quantizer):
amin
,
amax
=
tensor
.
aminmax
()
amin
,
amax
=
tensor
.
aminmax
()
self
.
amax
.
copy_
(
torch
.
max
(
-
amin
,
amax
))
self
.
amax
.
copy_
(
torch
.
max
(
-
amin
,
amax
))
def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> "list[int]":
    """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.

    The innermost dimension is moved to the front, e.g. ``(M, N) -> [N, M]``.

    Parameters
    ----------
    rowwise_data_shape : Iterable[int]
        Shape of the rowwise data. Any iterable of ints is accepted.

    Returns
    -------
    list[int]
        Columnwise data shape. (The original annotation claimed a tuple,
        but a list has always been returned; the annotation now matches.)

    Raises
    ------
    IndexError
        If ``rowwise_data_shape`` is empty.
    """
    # Materialize so true iterables (e.g. generators) work, matching the
    # declared Iterable parameter type.
    shape = list(rowwise_data_shape)
    return [shape[-1]] + shape[:-1]
def
create_tensor_from_data
(
def
create_tensor_from_data
(
self
,
self
,
data
:
torch
.
Tensor
,
data
:
torch
.
Tensor
,
...
@@ -409,6 +417,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
...
@@ -409,6 +417,10 @@ class Float8CurrentScalingQuantizer(Quantizer):
quantizer
=
self
,
quantizer
=
self
,
)
)
def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> "list[int]":
    """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.

    The innermost dimension is moved to the front, e.g. ``(M, N) -> [N, M]``.

    Parameters
    ----------
    rowwise_data_shape : Iterable[int]
        Shape of the rowwise data. Any iterable of ints is accepted.

    Returns
    -------
    list[int]
        Columnwise data shape. (The original annotation claimed a tuple,
        but a list has always been returned; the annotation now matches.)

    Raises
    ------
    IndexError
        If ``rowwise_data_shape`` is empty.
    """
    # Materialize so true iterables (e.g. generators) work, matching the
    # declared Iterable parameter type.
    shape = list(rowwise_data_shape)
    return [shape[-1]] + shape[:-1]
def
onnx_quantize
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
def
onnx_quantize
(
self
,
tensor
:
torch
.
Tensor
)
->
QuantizedTensor
:
"""Function using primitives with ONNX defined translations."""
"""Function using primitives with ONNX defined translations."""
if
tensor
.
dtype
!=
torch
.
float32
:
if
tensor
.
dtype
!=
torch
.
float32
:
...
@@ -770,7 +782,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
...
@@ -770,7 +782,10 @@ class Float8Tensor(Float8TensorStorage, QuantizedTensor):
kwargs
,
kwargs
,
)
)
return
Float8Tensor
.
make_like
(
return
Float8Tensor
.
make_like
(
tensor
,
data
=
func_out
,
data_transpose
=
func_transposed_out
,
shape
=
func_out
.
shape
tensor
,
data
=
func_out
,
data_transpose
=
func_transposed_out
,
shape
=
func_out
.
shape
,
)
)
if
func
==
torch
.
ops
.
aten
.
detach
.
default
:
if
func
==
torch
.
ops
.
aten
.
detach
.
default
:
...
...
transformer_engine/pytorch/tensor/mxfp8_tensor.py
View file @
9df0c4a3
...
@@ -164,6 +164,49 @@ class MXFP8Quantizer(Quantizer):
...
@@ -164,6 +164,49 @@ class MXFP8Quantizer(Quantizer):
# TODO(ksivamani): No calibration needed for mxfp8?
# TODO(ksivamani): No calibration needed for mxfp8?
pass
pass
def get_scale_shape(
    self,
    shape: Iterable[int],
    columnwise: bool,
) -> Tuple[int, int]:
    """Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization.

    This method determines the shape of the scaling tensor needed for blockwise quantization,
    taking into account the input tensor shape and whether columnwise scaling is used.

    Parameters
    ----------
    shape : Iterable[int]
        Shape of the input tensor to be quantized.
        NOTE(review): despite the ``Iterable`` annotation this must support
        indexing and slicing (``shape[:-1]``, ``shape[-1]``) — i.e. a Sequence.
    columnwise : bool
        Whether to use columnwise scaling (True) or rowwise scaling (False)

    Returns
    -------
    Tuple[int, int]
        Shape of the scaling tensor as (outer_dim, inner_dim)

    For MXFP8 1D blockwise quantization, blocksize is 32
    Swizzle kernel will be performed before GEMM to suit the need of CuBLAS.
    CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
    """
    if columnwise:
        # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]]
        # with padding to multiples of [4, 128]
        return (
            round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
            round_up_to_nearest_multiple(shape[-1], 128),
        )
    # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE]
    # with padding to multiples of [128, 4]
    return (
        round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
        round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
    )
def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization.

    MXFP8 stores columnwise data with the same logical shape as the rowwise
    data, so the input shape is returned unchanged.
    """
    columnwise_shape = rowwise_data_shape
    return columnwise_shape
def
create_tensor_from_data
(
def
create_tensor_from_data
(
self
,
self
,
data
:
torch
.
Tensor
,
data
:
torch
.
Tensor
,
...
@@ -704,7 +747,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
...
@@ -704,7 +747,7 @@ class MXFP8Tensor(MXFP8TensorStorage, QuantizedTensor):
columnwise_scale_inv
=
columnwise_scale_inv
,
columnwise_scale_inv
=
columnwise_scale_inv
,
fp8_dtype
=
fp8_dtype
,
fp8_dtype
=
fp8_dtype
,
dtype
=
param_dtype
,
dtype
=
param_dtype
,
shape
=
rowwise_data
.
shape
if
rowwise_data
is
not
None
else
columnwise_data
.
shape
,
shape
=
(
rowwise_data
.
shape
if
rowwise_data
is
not
None
else
columnwise_data
.
shape
)
,
quantizer
=
self
.
_quantizer
,
quantizer
=
self
.
_quantizer
,
with_gemm_swizzled_scales
=
False
,
with_gemm_swizzled_scales
=
False
,
)
)
...
...
transformer_engine/pytorch/tensor/nvfp4_tensor.py
View file @
9df0c4a3
...
@@ -341,7 +341,10 @@ class NVFP4Quantizer(Quantizer):
...
@@ -341,7 +341,10 @@ class NVFP4Quantizer(Quantizer):
)
)
columnwise_scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
True
)
columnwise_scale_shape
=
self
.
get_scale_shape
(
shape
,
columnwise
=
True
)
columnwise_scale_inv
=
torch
.
empty
(
columnwise_scale_inv
=
torch
.
empty
(
columnwise_scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
columnwise_scale_shape
,
dtype
=
torch
.
uint8
,
device
=
device
,
pin_memory
=
pin_memory
,
)
)
amax_columnwise
=
torch
.
zeros
(
amax_columnwise
=
torch
.
zeros
(
1
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
1
,
dtype
=
torch
.
float32
,
device
=
device
,
pin_memory
=
pin_memory
...
...
transformer_engine/pytorch/tensor/storage/__init__.py
View file @
9df0c4a3
...
@@ -7,3 +7,4 @@ from .float8_tensor_storage import Float8TensorStorage # noqa: F401
...
@@ -7,3 +7,4 @@ from .float8_tensor_storage import Float8TensorStorage # noqa: F401
from
.mxfp8_tensor_storage
import
MXFP8TensorStorage
# noqa: F401
from
.mxfp8_tensor_storage
import
MXFP8TensorStorage
# noqa: F401
from
.float8_blockwise_tensor_storage
import
Float8BlockwiseQTensorStorage
# noqa: F401
from
.float8_blockwise_tensor_storage
import
Float8BlockwiseQTensorStorage
# noqa: F401
from
.nvfp4_tensor_storage
import
NVFP4TensorStorage
# noqa: F401
from
.nvfp4_tensor_storage
import
NVFP4TensorStorage
# noqa: F401
from
.grouped_tensor
import
GroupedTensor
# noqa: F401
transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py
View file @
9df0c4a3
...
@@ -74,6 +74,24 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
...
@@ -74,6 +74,24 @@ class Float8BlockwiseQTensorStorage(QuantizedTensorStorage):
if
t
is
not
None
:
if
t
is
not
None
:
t
.
data
=
_empty_tensor
()
t
.
data
=
_empty_tensor
()
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
    """Copy data buffers from another Float8BlockwiseQTensorStorage."""
    if not isinstance(src, Float8BlockwiseQTensorStorage):
        raise TypeError("copy_from_storage expects Float8BlockwiseQTensorStorage")
    if self._fp8_dtype != src._fp8_dtype:
        raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
    if self._is_2D_scaled != src._is_2D_scaled:
        raise RuntimeError("Scale layout mismatch in copy_from_storage")
    # Copy each buffer that is present on both source and destination;
    # buffers that are None on either side are skipped silently.
    for attr_name in (
        "_rowwise_data",
        "_columnwise_data",
        "_rowwise_scale_inv",
        "_columnwise_scale_inv",
    ):
        dst_buf = getattr(self, attr_name)
        src_buf = getattr(src, attr_name)
        if dst_buf is not None and src_buf is not None:
            dst_buf.copy_(src_buf)
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
"""Get this tensor's metadata."""
"""Get this tensor's metadata."""
return
{
return
{
...
...
transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py
View file @
9df0c4a3
...
@@ -104,6 +104,24 @@ class Float8TensorStorage(QuantizedTensorStorage):
...
@@ -104,6 +104,24 @@ class Float8TensorStorage(QuantizedTensorStorage):
t
.
data
=
_empty_tensor
()
t
.
data
=
_empty_tensor
()
self
.
_transpose_invalid
=
True
self
.
_transpose_invalid
=
True
def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
    """Copy data buffers from another Float8TensorStorage."""
    if not isinstance(src, Float8TensorStorage):
        raise TypeError("copy_from_storage expects Float8TensorStorage")
    if self._fp8_dtype != src._fp8_dtype:
        raise RuntimeError("FP8 dtype mismatch in copy_from_storage")
    # Copy each buffer that is present on both source and destination;
    # buffers that are None on either side are skipped silently.
    for attr_name in ("_data", "_transpose", "_scale_inv"):
        dst_buf = getattr(self, attr_name)
        src_buf = getattr(src, attr_name)
        if dst_buf is not None and src_buf is not None:
            dst_buf.copy_(src_buf)
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
"""Get this tensor's metadata."""
"""Get this tensor's metadata."""
return
{
return
{
...
...
transformer_engine/pytorch/tensor/storage/grouped_tensor.py
0 → 100644
View file @
9df0c4a3
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
"""Grouped tensor class for handling collections of tensors with different shapes"""
from
__future__
import
annotations
from
typing
import
Optional
,
Tuple
,
List
,
Union
import
math
import
torch
from
...quantized_tensor
import
QuantizedTensorStorage
,
Quantizer
from
..mxfp8_tensor
import
MXFP8Tensor
from
..nvfp4_tensor
import
NVFP4Tensor
from
..float8_tensor
import
Float8Tensor
from
..float8_blockwise_tensor
import
Float8BlockwiseQTensor
from
.float8_tensor_storage
import
Float8TensorStorage
from
.mxfp8_tensor_storage
import
MXFP8TensorStorage
from
.float8_blockwise_tensor_storage
import
Float8BlockwiseQTensorStorage
from
.nvfp4_tensor_storage
import
NVFP4TensorStorage
class
GroupedTensor
:
"""
EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE.
Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode.
Shape Representation:
- logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors
are flattened to 2D and stacked together (REQUIRED)
+ When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N)
+ When varying_first_dim(): [~sum_of_first_dims, N] where N is common
+ When varying_last_dim(): [M, ~sum_of_last_dims] where M is common
+ When varying_both_dims(): [1, total_elements] (fully flattened)
- first_dims and last_dims are OPTIONAL (None if dimension is uniform)
+ None first_dims: all tensors have the same first dimension
+ None last_dims: all tensors have the same last dimension
+ Both None: all tensors have identical shapes
+ Both set: each tensor has unique shape (first_dims[i], last_dims[i])
Data Layout:
- ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.)
- logical_shape provides the conceptual 2D interpretation
- All data is stored on device in contiguous layout
Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode.
"""
def __init__(
    self,
    num_tensors: int,
    shape: List[Tuple[int, int]],
    quantizer: Optional[Quantizer] = None,
    dtype: Optional[torch.dtype] = None,
    data: Optional[torch.Tensor] = None,
    columnwise_data: Optional[torch.Tensor] = None,
    scale_inv: Optional[torch.Tensor] = None,
    columnwise_scale_inv: Optional[torch.Tensor] = None,
    amax: Optional[torch.Tensor] = None,
    columnwise_amax: Optional[torch.Tensor] = None,
    scale: Optional[torch.Tensor] = None,
    first_dims: Optional[torch.Tensor] = None,
    last_dims: Optional[torch.Tensor] = None,
    tensor_offsets: Optional[torch.Tensor] = None,
    offsets: Optional[List[int]] = None,
    scale_inv_offsets: Optional[List[int]] = None,
    columnwise_scale_inv_offsets: Optional[List[int]] = None,
    logical_shape: Optional[Tuple[int, int]] = None,
) -> None:
    """
    Initialize a GroupedTensor.
    Args:
        num_tensors: Number of tensors in the group
        shape: 2D shape of each tensor (len num_tensors)
        quantizer: Quantizer for the grouped tensor
        dtype: High-precision dtype of the data (defaults to torch.float32)
        data: Row-wise data buffer (1D flattened)
        columnwise_data: Column-wise data buffer (1D flattened)
        scale_inv: Row-wise scale inverse buffer
        columnwise_scale_inv: Column-wise scale inverse buffer
        amax: Row-wise amax buffer
        columnwise_amax: Column-wise amax buffer
        scale: Scale buffer (for FP8-DS only)
        first_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
        last_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
        tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform)
        offsets: Vector of integer offsets for each tensor.
        scale_inv_offsets: Per-tensor offsets into scale_inv (for Python-side indexing)
        columnwise_scale_inv_offsets: Per-tensor offsets into columnwise_scale_inv
        logical_shape: 2D tuple representing conceptual shape
    """
    self.num_tensors = num_tensors
    self.quantizer = quantizer
    self.shape = shape
    self.dtype = (
        dtype if dtype is not None else torch.float32
    )  # Default to float32 if not provided
    # Data buffers
    self.data = data
    self.columnwise_data = columnwise_data
    self.scale_inv = scale_inv
    self.columnwise_scale_inv = columnwise_scale_inv
    self.amax = amax
    self.columnwise_amax = columnwise_amax
    self.scale = scale
    # For convenient indexing for python GroupedTensor API.
    self.scale_inv_offsets = scale_inv_offsets
    self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets
    # Shape information (OPTIONAL - None if dimension is uniform across all tensors)
    # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim)
    # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim)
    self.first_dims = (
        first_dims  # Device pointer to int64_t array of length num_tensors (or None)
    )
    self.last_dims = (
        last_dims  # Device pointer to int64_t array of length num_tensors (or None)
    )
    # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape())
    # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1)
    # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size
    # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions)
    self.tensor_offsets = (
        tensor_offsets  # Device pointer to int64_t array of length num_tensors (or None)
    )
    self.offsets = offsets  # Vector of integer offsets for each tensor.
    # Logical shape: conceptual 2D shape of the grouped data (REQUIRED)
    # Represents how the 1D flattened data should be interpreted as 2D
    # Always 2D with positive dimensions
    self.logical_shape = logical_shape if logical_shape is not None else (0, 0)
    # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor.
    # Used as a convenience.
    self.quantized_tensors = None
def
has_data
(
self
)
->
bool
:
"""
Check if the tensor has row-wise data.
Returns:
True if data buffer is initialized, False otherwise
"""
return
self
.
data
is
not
None
def
has_columnwise_data
(
self
)
->
bool
:
"""
Check if the tensor has column-wise data.
Returns:
True if columnwise_data buffer is initialized, False otherwise
"""
return
self
.
columnwise_data
is
not
None
def
all_same_first_dim
(
self
)
->
bool
:
"""
Check if all tensors in the group have the same first dimension.
Returns:
True if first dimension is uniform across all tensors
"""
return
self
.
first_dims
is
None
def
all_same_last_dim
(
self
)
->
bool
:
"""
Check if all tensors in the group have the same last dimension.
Returns:
True if last dimension is uniform across all tensors
"""
return
self
.
last_dims
is
None
def
all_same_shape
(
self
)
->
bool
:
"""
Check if all tensors in the group have identical shapes.
Returns:
True if all tensors have the same shape
"""
return
self
.
first_dims
is
None
and
self
.
last_dims
is
None
def
varying_both_dims
(
self
)
->
bool
:
"""
Check if both dimensions vary across tensors.
Returns:
True if both first and last dimensions vary
"""
return
self
.
first_dims
is
not
None
and
self
.
last_dims
is
not
None
def
get_common_first_dim
(
self
)
->
int
:
"""
Get the common first dimension when all tensors share it.
Returns:
The common first dimension
Raises:
RuntimeError: If first dimension varies across tensors or logical_shape is not 2D
"""
if
not
self
.
all_same_first_dim
():
raise
RuntimeError
(
"First dim varies across tensors"
)
if
len
(
self
.
logical_shape
)
!=
2
:
raise
RuntimeError
(
"Logical shape must be 2D"
)
if
self
.
all_same_shape
():
# When both dims are uniform: logical_shape = [num_tensors * M, N]
return
self
.
logical_shape
[
0
]
//
self
.
num_tensors
# When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims]
return
self
.
logical_shape
[
0
]
def
get_common_last_dim
(
self
)
->
int
:
"""
Get the common last dimension when all tensors share it.
Returns:
The common last dimension
Raises:
RuntimeError: If last dimension varies across tensors or logical_shape is not 2D
"""
if
not
self
.
all_same_last_dim
():
raise
RuntimeError
(
"Last dim varies across tensors"
)
if
len
(
self
.
logical_shape
)
!=
2
:
raise
RuntimeError
(
"Logical shape must be 2D"
)
# For both uniform and varying first dim cases: logical_shape[1] is the common last dim
return
self
.
logical_shape
[
1
]
def
get_dtype
(
self
)
->
torch
.
dtype
:
"""
Get the high precision data type of the tensor.
Returns:
The high precision dtype of the data buffer
"""
return
self
.
dtype
def
clear
(
self
)
->
None
:
"""
Reset tensor data and clear all buffers.
"""
self
.
data
=
None
self
.
columnwise_data
=
None
self
.
scale_inv
=
None
self
.
columnwise_scale_inv
=
None
self
.
amax
=
None
self
.
columnwise_amax
=
None
self
.
scale
=
None
self
.
first_dims
=
None
self
.
last_dims
=
None
self
.
tensor_offsets
=
None
self
.
logical_shape
=
(
0
,
0
)
self
.
num_tensors
=
0
self
.
quantizer
=
None
self
.
quantized_tensors
=
None
self
.
offsets
=
None
self
.
scale_inv_offsets
=
None
self
.
columnwise_scale_inv_offsets
=
None
def
__repr__
(
self
)
->
str
:
"""String representation of the GroupedTensor."""
return
(
f
"GroupedTensor(num_tensors=
{
self
.
num_tensors
}
, "
f
"shape=
{
self
.
shape
}
, "
f
"logical_shape=
{
self
.
logical_shape
}
, "
f
"dtype=
{
self
.
get_dtype
()
}
)"
)
def
__str__
(
self
)
->
str
:
"""User-friendly string representation."""
shape_info
=
[]
if
self
.
all_same_shape
():
shape_info
.
append
(
"uniform shape"
)
else
:
if
not
self
.
all_same_first_dim
():
shape_info
.
append
(
"varying first dim"
)
if
not
self
.
all_same_last_dim
():
shape_info
.
append
(
"varying last dim"
)
return
(
f
"GroupedTensor with
{
self
.
num_tensors
}
tensors "
f
"(
{
', '
.
join
(
shape_info
)
if
shape_info
else
'uniform'
}
), "
f
"logical_shape=
{
self
.
logical_shape
}
, "
f
"dtype=
{
self
.
get_dtype
()
}
"
)
@
staticmethod
def
make_grouped_tensor_with_shapes
(
num_tensors
:
int
,
shape
:
List
[
Tuple
[
int
,
int
]],
quantizer
:
Optional
[
Quantizer
]
=
None
,
device
:
Optional
[
torch
.
device
]
=
None
,
dtype
:
Optional
[
torch
.
dtype
]
=
None
,
)
->
GroupedTensor
:
"""
Create a GroupedTensor for storing multiple weight tensors of the same shape.
Args:
num_tensors: Number of tensors
shape: 2D shape of each tensor (len num_tensors)
quantizer: Quantizer for each tensor
device: Device to allocate tensors on, defaults to current cuda device
dtype: Data type of the tensor (for high precision case)
Returns:
A GroupedTensor.
"""
# First dim
first_dim_list
=
[
s
[
0
]
for
s
in
shape
]
uniform_first_dim
=
all
(
first_dim_list
[
0
]
==
x
for
x
in
first_dim_list
)
logical_first_dim
=
sum
(
first_dim_list
)
if
uniform_first_dim
:
first_dims
=
None
else
:
first_dims
=
torch
.
tensor
([
s
[
0
]
for
s
in
shape
],
dtype
=
torch
.
int64
,
device
=
device
)
# Last dim
last_dim_list
=
[
s
[
1
]
for
s
in
shape
]
logical_last_dim
=
last_dim_list
[
0
]
assert
all
(
logical_last_dim
==
x
for
x
in
last_dim_list
),
"Last dims should be uniform"
return
GroupedTensor
.
make_grouped_tensor
(
num_tensors
=
num_tensors
,
first_dims
=
first_dims
,
last_dims
=
None
,
logical_first_dim
=
logical_first_dim
,
logical_last_dim
=
logical_last_dim
,
quantizer
=
quantizer
,
device
=
device
,
dtype
=
dtype
,
)
    @staticmethod
    def make_grouped_tensor(
        num_tensors: int,
        first_dims: Optional[torch.Tensor],
        last_dims: Optional[torch.Tensor],
        logical_first_dim: int,
        logical_last_dim: int,
        quantizer: Optional[Quantizer] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> GroupedTensor:
        """
        Create a GroupedTensor backed by freshly allocated flat buffers.

        What gets allocated depends on the quantization recipe:
        no quantizer -> high-precision data only; MXFP8 / NVFP4 /
        FP8 block scaling -> uint8 data plus per-block scale-inverse
        buffers; delayed / current scaling -> uint8 data plus one
        scale-inverse (and amax/scale) entry per tensor.

        Args:
            num_tensors: Number of tensors
            first_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
            last_dims: Device tensor of int64 array of length num_tensors (or None if uniform)
            logical_first_dim: Logical first dimension
            logical_last_dim: Logical last dimension
            quantizer: Quantizer for each tensor
                Used to figure out the recipe and what to allocate.
            device: Device to allocate tensors on, defaults to current cuda device
            dtype: Data type of the tensor (for high precision case)
        Returns:
            A GroupedTensor.
        """
        # Set device
        if device is None:
            device = torch.cuda.current_device()
        # Shape patterns and validation.
        all_same_first = first_dims is None
        all_same_last = last_dims is None
        assert all_same_last, "Last dim must be uniform for GroupedTensor"
        assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor"
        assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor"
        # assert (
        #     logical_first_dim % 128 == 0
        # ), "Logical first dim must be divisible by 128"
        # assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128"
        # Calculate tensor offsets (cumulative element offsets)
        tensor_offsets = None
        offsets = None
        shape = []
        if not all_same_first:
            # Need explicit offsets for non-uniform shapes
            # Offsets are based on number of elements and not pointers.
            # Kernels need to calculate precise pointers based on size of elements.
            # TODO(ksivaman): Single kernel + remove the host offset calculation.
            tensor_offsets = torch.cat(
                [
                    torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
                    torch.cumsum(first_dims * logical_last_dim, dim=0),
                ]
            )
            # NOTE: .tolist() forces a device->host sync to build the host-side offsets.
            offsets = tensor_offsets.tolist()
            first_dims_list = first_dims.tolist()
            for i in range(num_tensors):
                shape.append((first_dims_list[i], logical_last_dim))
        else:
            # Uniform case: each tensor owns an equal slice of the flat buffer.
            offsets = [
                i * logical_first_dim * logical_last_dim // num_tensors
                for i in range(num_tensors + 1)
            ]
            for i in range(num_tensors):
                shape.append((logical_first_dim // num_tensors, logical_last_dim))
        # Calculate logical shape based
        logical_shape = (logical_first_dim, logical_last_dim)
        no_quantization = quantizer is None
        # Unquantized groups default to row-wise data only.
        rowwise_usage = quantizer.rowwise_usage if not no_quantization else True
        columnwise_usage = quantizer.columnwise_usage if not no_quantization else False
        # Calculate total elements across all tensors
        total_elements = logical_first_dim * logical_last_dim
        data = None
        columnwise_data = None
        scale_inv = None
        columnwise_scale_inv = None
        amax = None
        columnwise_amax = None
        scale = None
        scale_inv_offsets = None
        columnwise_scale_inv_offsets = None
        if no_quantization:
            assert dtype is not None, "dtype must be provided for unquantized GroupedTensor"
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, high-precision dtype)
                data = torch.empty(total_elements, dtype=dtype, device=device)
            if columnwise_usage:
                # Allocate columnwise data buffer (1D flattened, high-precision dtype)
                columnwise_data = torch.empty(total_elements, dtype=dtype, device=device)
        elif quantizer._get_compatible_recipe().mxfp8():
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, uint8)
                data = torch.empty(total_elements, dtype=torch.uint8, device=device)
                # Scale inverse buffer for MXFP8 - complex shape based on block scaling
                # For grouped tensors, we need to calculate scale_inv size for all tensors
                total_scale_elements = 0
                scale_inv_offsets = [0]
                for i, s in enumerate(shape):
                    scale_inv_shape = quantizer.get_scale_shape(s, False)
                    scale_elements = math.prod(scale_inv_shape)
                    total_scale_elements += scale_elements
                    if i < num_tensors - 1:
                        scale_inv_offsets.append(total_scale_elements)
                scale_inv = torch.empty(
                    total_scale_elements, dtype=torch.uint8, device=device
                )
            if columnwise_usage:
                # Allocate columnwise data buffer (1D flattened, uint8)
                columnwise_data = torch.empty(
                    total_elements, dtype=torch.uint8, device=device
                )
                # Columnwise scale inverse buffer
                total_columnwise_scale_elements = 0
                columnwise_scale_inv_offsets = [0]
                for i, s in enumerate(shape):
                    # NOTE(review): this passes columnwise=False, while the NVFP4 and
                    # FP8-block-scaling branches pass True for their columnwise scale
                    # sizing, and split_into_quantized_tensors views this buffer with
                    # get_scale_shape(s, True). Confirm the element counts agree.
                    scale_inv_shape = quantizer.get_scale_shape(s, False)
                    columnwise_scale_elements = math.prod(scale_inv_shape)
                    total_columnwise_scale_elements += columnwise_scale_elements
                    if i < num_tensors - 1:
                        columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
                columnwise_scale_inv = torch.empty(
                    total_columnwise_scale_elements, dtype=torch.uint8, device=device
                )
        elif quantizer._get_compatible_recipe().delayed():
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, uint8)
                data = torch.empty(total_elements, dtype=torch.uint8, device=device)
                # Scale inverse - one per tensor
                scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
                # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
                scale_inv_offsets = list(range(num_tensors))
            if columnwise_usage:
                # Allocate columnwise data buffer (1D flattened, uint8)
                columnwise_data = torch.empty(
                    total_elements, dtype=torch.uint8, device=device
                )
                # Columnwise scale inverse - one per tensor
                columnwise_scale_inv = torch.empty(
                    num_tensors, dtype=torch.float32, device=device
                )
                # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
                columnwise_scale_inv_offsets = list(range(num_tensors))
            # Amax buffer for delayed scaling - one per tensor
            amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
        elif quantizer._get_compatible_recipe().nvfp4():
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte)
                data = torch.empty(
                    (total_elements) // 2, dtype=torch.uint8, device=device
                )
                # Scale inverse buffer for NVFP4 - complex shape based on block scaling
                # For simplicity, calculate total scale elements needed
                total_scale_elements = 0
                scale_inv_offsets = [0]
                for i, s in enumerate(shape):
                    scale_inv_shape = quantizer.get_scale_shape(s, False)
                    total_scale_elements += math.prod(scale_inv_shape)
                    if i < num_tensors - 1:
                        scale_inv_offsets.append(total_scale_elements)
                scale_inv = torch.empty(
                    total_scale_elements, dtype=torch.uint8, device=device
                )
                # Amax buffer - one per tensor
                amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
            if columnwise_usage:
                # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed)
                columnwise_data = torch.empty(
                    (total_elements) // 2, dtype=torch.uint8, device=device
                )
                # Columnwise scale inverse buffer
                total_columnwise_scale_elements = 0
                columnwise_scale_inv_offsets = [0]
                for i, s in enumerate(shape):
                    columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True)
                    total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape)
                    if i < num_tensors - 1:
                        columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
                columnwise_scale_inv = torch.empty(
                    total_columnwise_scale_elements, dtype=torch.uint8, device=device
                )
                # Columnwise amax buffer - one per tensor
                columnwise_amax = torch.empty(
                    num_tensors, dtype=torch.float32, device=device
                )
        elif quantizer._get_compatible_recipe().float8_block_scaling():
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, uint8)
                data = torch.empty(total_elements, dtype=torch.uint8, device=device)
                # Scale inverse - size depends on block configuration
                # For simplicity, calculate total scale elements needed
                total_scale_elements = 0
                scale_inv_offsets = [0]
                for i, s in enumerate(shape):
                    scale_inv_shape = quantizer.get_scale_shape(s, False)
                    total_scale_elements += math.prod(scale_inv_shape)
                    if i < num_tensors - 1:
                        scale_inv_offsets.append(total_scale_elements)
                scale_inv = torch.empty(
                    total_scale_elements, dtype=torch.float32, device=device
                )
            if columnwise_usage:
                # Allocate columnwise data buffer (1D flattened, uint8)
                columnwise_data = torch.empty(
                    total_elements, dtype=torch.uint8, device=device
                )
                # Columnwise scale inverse
                total_columnwise_scale_elements = 0
                columnwise_scale_inv_offsets = [0]
                for i, s in enumerate(shape):
                    columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True)
                    total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape)
                    if i < num_tensors - 1:
                        columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
                columnwise_scale_inv = torch.empty(
                    total_columnwise_scale_elements, dtype=torch.float32, device=device
                )
        elif quantizer._get_compatible_recipe().float8_current_scaling():
            # Current scaling - per-tensor scaling computed on the fly
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, uint8)
                data = torch.empty(total_elements, dtype=torch.uint8, device=device)
                # Scale inverse - one per tensor
                scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
                # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
                scale_inv_offsets = list(range(num_tensors))
            if columnwise_usage:
                # Allocate columnwise data buffer (1D flattened, uint8)
                columnwise_data = torch.empty(
                    total_elements, dtype=torch.uint8, device=device
                )
                # Columnwise scale inverse - one per tensor
                columnwise_scale_inv = torch.empty(
                    num_tensors, dtype=torch.float32, device=device
                )
                # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
                columnwise_scale_inv_offsets = list(range(num_tensors))
            # Scale and amax buffers for current scaling - one per tensor
            scale = torch.empty(num_tensors, dtype=torch.float32, device=device)
            amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
        else:
            raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}")
        grouped_tensor = GroupedTensor(
            num_tensors=num_tensors,
            shape=shape,
            dtype=dtype,
            quantizer=quantizer,
            data=data,
            columnwise_data=columnwise_data,
            scale_inv=scale_inv,
            columnwise_scale_inv=columnwise_scale_inv,
            amax=amax,
            columnwise_amax=columnwise_amax,
            scale=scale,
            first_dims=first_dims,
            last_dims=last_dims,
            tensor_offsets=tensor_offsets,
            offsets=offsets,
            scale_inv_offsets=scale_inv_offsets,
            columnwise_scale_inv_offsets=columnwise_scale_inv_offsets,
            logical_shape=logical_shape,
        )
        # Cache the per-tensor quantized views that alias the grouped storage.
        grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
        return grouped_tensor
    def split_into_quantized_tensors(
        self,
    ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]:
        """
        Split the GroupedTensor into a list of `num_tensors`
        quantized tensors based on the quantizer. No additional memory allocation is performed,
        so the tensors returned are the same as the ones used to create the GroupedTensor.
        If quantizer is None, returns normal torch tensors.
        If quantizer.internal is True, returns QuantizedTensorStorage.
        Otherwise, returns QuantizedTensor.
        TODO(ksivaman): Block cases where any dims are varying. This is needed only
        to expose the weights as separate parameters.
        """
        result = []
        no_quantization = self.quantizer is None
        # Case 1: No quantization - return regular torch tensors
        if no_quantization:
            for i in range(self.num_tensors):
                # Get tensor shape
                tensor_shape = self.shape[i]
                # Get tensor data slice
                if self.offsets is not None:
                    start_offset = self.offsets[i]
                    numel = tensor_shape[0] * tensor_shape[1]
                    end_offset = start_offset + numel
                    if self.has_data():
                        tensor_data = self.data[start_offset:end_offset].view(tensor_shape)
                        result.append(tensor_data)
                    elif self.has_columnwise_data():
                        tensor_data = self.columnwise_data[start_offset:end_offset].view(
                            tensor_shape
                        )
                        result.append(tensor_data)
                    else:
                        raise RuntimeError("GroupedTensor has no data to split")
                else:
                    # All same shape case
                    numel = tensor_shape[0] * tensor_shape[1]
                    start_offset = i * numel
                    end_offset = start_offset + numel
                    if self.has_data():
                        tensor_data = self.data[start_offset:end_offset].view(tensor_shape)
                        result.append(tensor_data)
                    elif self.has_columnwise_data():
                        tensor_data = self.columnwise_data[start_offset:end_offset].view(
                            tensor_shape
                        )
                        result.append(tensor_data)
                    else:
                        raise RuntimeError("GroupedTensor has no data to split")
            return result
        # Case 2: Quantized tensors
        recipe = self.quantizer._get_compatible_recipe()
        for i in range(self.num_tensors):
            # Get tensor shape
            tensor_shape = self.shape[i]
            numel = tensor_shape[0] * tensor_shape[1]
            # Get data offsets
            if self.offsets is not None:
                data_start = self.offsets[i]
                data_end = data_start + numel
            else:
                # All same shape
                data_start = i * numel
                data_end = data_start + numel
            # Special shape handling for NVFP4.
            nvfp4 = self.quantizer._get_compatible_recipe().nvfp4()
            if nvfp4:
                # FP4 packs two values per byte, so byte offsets are half the
                # element offsets.
                data_start = data_start // 2
                data_end = data_end // 2
            # Extract rowwise and columnwise data
            rowwise_data = None
            columnwise_data = None
            if self.has_data():
                if nvfp4:
                    rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape)
                else:
                    rowwise_tensor_shape = tensor_shape
                rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape)
            if self.has_columnwise_data():
                columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape)
                if nvfp4:
                    columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4(
                        columnwise_tensor_shape
                    )
                columnwise_data = self.columnwise_data[data_start:data_end].view(
                    columnwise_tensor_shape
                )
            # MXFP8 format
            if recipe.mxfp8():
                # Extract scale_inv data
                rowwise_scale_inv = None
                columnwise_scale_inv = None
                if self.scale_inv is not None and self.scale_inv_offsets is not None:
                    scale_start = self.scale_inv_offsets[i]
                    # Last tensor's slice runs to the end of the flat buffer.
                    if i < self.num_tensors - 1:
                        scale_end = self.scale_inv_offsets[i + 1]
                    else:
                        scale_end = self.scale_inv.numel()
                    # Calculate expected scale shape for MXFP8
                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
                if (
                    self.columnwise_scale_inv is not None
                    and self.columnwise_scale_inv_offsets is not None
                ):
                    cscale_start = self.columnwise_scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
                    else:
                        cscale_end = self.columnwise_scale_inv.numel()
                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
                    columnwise_scale_inv = self.columnwise_scale_inv[
                        cscale_start:cscale_end
                    ].view(cscale_shape)
                # Internal quantizers use the lightweight storage class, public
                # ones the full tensor subclass.
                if self.quantizer.internal:
                    mxfp8_tensor_class = MXFP8TensorStorage
                else:
                    mxfp8_tensor_class = MXFP8Tensor
                tensor = mxfp8_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    rowwise_data=rowwise_data,
                    rowwise_scale_inv=rowwise_scale_inv,
                    columnwise_data=columnwise_data,
                    columnwise_scale_inv=columnwise_scale_inv,
                    fp8_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
                )
                result.append(tensor)
            # Delayed scaling or current scaling (both use Float8TensorStorage)
            elif recipe.delayed() or recipe.float8_current_scaling():
                # Scale inverse - one per tensor
                scale_inv = None
                if self.scale_inv is not None:
                    # [i:i+1] keeps a 1-element view into the shared buffer.
                    scale_inv = self.scale_inv[i : i + 1]
                if self.quantizer.internal:
                    float8_tensor_class = Float8TensorStorage
                else:
                    float8_tensor_class = Float8Tensor
                tensor = float8_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    data=rowwise_data,
                    fp8_scale_inv=scale_inv,
                    fp8_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    data_transpose=columnwise_data,
                )
                result.append(tensor)
            # Float8 block scaling
            elif recipe.float8_block_scaling():
                # Extract scale_inv data
                rowwise_scale_inv = None
                columnwise_scale_inv = None
                if self.scale_inv is not None and self.scale_inv_offsets is not None:
                    scale_start = self.scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        scale_end = self.scale_inv_offsets[i + 1]
                    else:
                        scale_end = self.scale_inv.numel()
                    # Get scale shape from quantizer
                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
                if (
                    self.columnwise_scale_inv is not None
                    and self.columnwise_scale_inv_offsets is not None
                ):
                    cscale_start = self.columnwise_scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
                    else:
                        cscale_end = self.columnwise_scale_inv.numel()
                    # Get columnwise scale shape from quantizer
                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
                    columnwise_scale_inv = self.columnwise_scale_inv[
                        cscale_start:cscale_end
                    ].view(cscale_shape)
                # Compute is_2D_scaled and data_format from quantizer attributes
                is_2D_scaled = self.quantizer.block_scaling_dim == 2
                if self.quantizer.internal:
                    float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage
                else:
                    float8_blockwise_q_tensor_class = Float8BlockwiseQTensor
                tensor = float8_blockwise_q_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    rowwise_data=rowwise_data,
                    rowwise_scale_inv=rowwise_scale_inv,
                    columnwise_data=columnwise_data,
                    columnwise_scale_inv=columnwise_scale_inv,
                    fp8_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    is_2D_scaled=is_2D_scaled,
                )
                result.append(tensor)
            # NVFP4 format
            elif recipe.nvfp4():
                # Extract scale_inv data
                rowwise_scale_inv = None
                columnwise_scale_inv = None
                amax_rowwise = None
                amax_columnwise = None
                if self.scale_inv is not None and self.scale_inv_offsets is not None:
                    scale_start = self.scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        scale_end = self.scale_inv_offsets[i + 1]
                    else:
                        scale_end = self.scale_inv.numel()
                    # Get scale shape from quantizer
                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
                if (
                    self.columnwise_scale_inv is not None
                    and self.columnwise_scale_inv_offsets is not None
                ):
                    cscale_start = self.columnwise_scale_inv_offsets[i]
                    if i < self.num_tensors - 1:
                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
                    else:
                        cscale_end = self.columnwise_scale_inv.numel()
                    # Get columnwise scale shape from quantizer
                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
                    columnwise_scale_inv = self.columnwise_scale_inv[
                        cscale_start:cscale_end
                    ].view(cscale_shape)
                # Extract amax - one per tensor
                if self.amax is not None:
                    amax_rowwise = self.amax[i : i + 1]
                if self.columnwise_amax is not None:
                    amax_columnwise = self.columnwise_amax[i : i + 1]
                if self.quantizer.internal:
                    nvfp4_tensor_class = NVFP4TensorStorage
                else:
                    nvfp4_tensor_class = NVFP4Tensor
                tensor = nvfp4_tensor_class(
                    shape=tensor_shape,
                    dtype=self.dtype,
                    rowwise_data=rowwise_data,
                    rowwise_scale_inv=rowwise_scale_inv,
                    columnwise_data=columnwise_data,
                    columnwise_scale_inv=columnwise_scale_inv,
                    amax_rowwise=amax_rowwise,
                    amax_columnwise=amax_columnwise,
                    fp4_dtype=self.quantizer.dtype,
                    quantizer=self.quantizer,
                    with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
                )
                result.append(tensor)
            else:
                raise ValueError(f"Unsupported quantization recipe: {recipe}")
        return result
    @staticmethod
    def create_and_quantize(
        # Annotation fixed: was `int`, but the body calls len(tensors) and
        # iterates t.shape, so this is a list of tensors.
        tensors: List[torch.Tensor],
        quantizer: None | Quantizer,
        *,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> GroupedTensor:
        # Return annotation fixed: the function returns the GroupedTensor
        # itself (the per-tensor quantized views are available via its
        # `quantized_tensors` attribute), not a tuple of storages.
        """
        Quantize given tensors into quantized tensors with underlying
        storage allocated in a GroupedTensor.
        """
        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=len(tensors),
            shape=[t.shape for t in tensors],
            quantizer=quantizer,
            device=device,
            dtype=dtype,
        )
        # Quantize in place; the returned views are discarded here because the
        # grouped tensor already caches them.
        grouped_tensor.quantize(tensors, noop_flag=noop_flag)
        return grouped_tensor
def
quantize
(
self
,
tensors
:
List
[
torch
.
Tensor
],
noop_flag
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
Tuple
[
QuantizedTensorStorage
,
...]:
"""
Quantize the GroupedTensor inplace.
"""
quantized_tensors
=
self
.
split_into_quantized_tensors
()
for
i
in
range
(
self
.
num_tensors
):
self
.
quantizer
.
update_quantized
(
tensors
[
i
],
quantized_tensors
[
i
],
noop_flag
=
noop_flag
)
return
quantized_tensors
transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py
View file @
9df0c4a3
...
@@ -111,6 +111,24 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
...
@@ -111,6 +111,24 @@ class MXFP8TensorStorage(QuantizedTensorStorage):
if
t
is
not
None
:
if
t
is
not
None
:
t
.
data
=
_empty_tensor
()
t
.
data
=
_empty_tensor
()
def
copy_from_storage
(
self
,
src
:
QuantizedTensorStorage
)
->
None
:
"""Copy data buffers from another MXFP8TensorStorage."""
if
not
isinstance
(
src
,
MXFP8TensorStorage
):
raise
TypeError
(
"copy_from_storage expects MXFP8TensorStorage"
)
if
self
.
_fp8_dtype
!=
src
.
_fp8_dtype
:
raise
RuntimeError
(
"FP8 dtype mismatch in copy_from_storage"
)
if
self
.
_with_gemm_swizzled_scales
!=
src
.
_with_gemm_swizzled_scales
:
raise
RuntimeError
(
"Scale layout mismatch in copy_from_storage"
)
def
_copy_optional
(
dst
:
Optional
[
torch
.
Tensor
],
src_tensor
:
Optional
[
torch
.
Tensor
]):
if
dst
is
not
None
and
src_tensor
is
not
None
:
dst
.
copy_
(
src_tensor
)
_copy_optional
(
self
.
_rowwise_data
,
src
.
_rowwise_data
)
_copy_optional
(
self
.
_columnwise_data
,
src
.
_columnwise_data
)
_copy_optional
(
self
.
_rowwise_scale_inv
,
src
.
_rowwise_scale_inv
)
_copy_optional
(
self
.
_columnwise_scale_inv
,
src
.
_columnwise_scale_inv
)
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
"""Get this tensor's metadata."""
"""Get this tensor's metadata."""
return
{
return
{
...
...
transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
View file @
9df0c4a3
...
@@ -136,6 +136,26 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
...
@@ -136,6 +136,26 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
if
t
is
not
None
:
if
t
is
not
None
:
t
.
data
=
_empty_tensor
()
t
.
data
=
_empty_tensor
()
def
copy_from_storage
(
self
,
src
:
QuantizedTensorStorage
)
->
None
:
"""Copy data buffers from another NVFP4TensorStorage."""
if
not
isinstance
(
src
,
NVFP4TensorStorage
):
raise
TypeError
(
"copy_from_storage expects NVFP4TensorStorage"
)
if
self
.
_fp4_dtype
!=
src
.
_fp4_dtype
:
raise
RuntimeError
(
"FP4 dtype mismatch in copy_from_storage"
)
if
self
.
_with_gemm_swizzled_scales
!=
src
.
_with_gemm_swizzled_scales
:
raise
RuntimeError
(
"Scale layout mismatch in copy_from_storage"
)
def
_copy_optional
(
dst
:
Optional
[
torch
.
Tensor
],
src_tensor
:
Optional
[
torch
.
Tensor
]):
if
dst
is
not
None
and
src_tensor
is
not
None
:
dst
.
copy_
(
src_tensor
)
_copy_optional
(
self
.
_rowwise_data
,
src
.
_rowwise_data
)
_copy_optional
(
self
.
_columnwise_data
,
src
.
_columnwise_data
)
_copy_optional
(
self
.
_rowwise_scale_inv
,
src
.
_rowwise_scale_inv
)
_copy_optional
(
self
.
_columnwise_scale_inv
,
src
.
_columnwise_scale_inv
)
_copy_optional
(
self
.
_amax_rowwise
,
src
.
_amax_rowwise
)
_copy_optional
(
self
.
_amax_columnwise
,
src
.
_amax_columnwise
)
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
def
get_metadata
(
self
)
->
Dict
[
str
,
Any
]:
"""Get this tensor's metadata."""
"""Get this tensor's metadata."""
return
{
return
{
...
...
Prev
1
…
7
8
9
10
11
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment