OpenDAS / TransformerEngine / Commits / f8c2af4c

Commit f8c2af4c, authored May 21, 2025 by yuguo

    Merge commit '1d903f5e' of https://github.com/NVIDIA/TransformerEngine

Parents: e92773a3, 1d903f5e

Showing 11 changed files with 828 additions and 217 deletions (+828, -217)
Files changed:

    transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py   +114  -9
    transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py              +59   -4
    transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py               +74   -3
    transformer_engine/pytorch/tensor/float8_blockwise_tensor.py                   +138  -58
    transformer_engine/pytorch/tensor/float8_tensor.py                             +1    -44
    transformer_engine/pytorch/tensor/mxfp8_tensor.py                              +0    -50
    transformer_engine/pytorch/tensor/quantized_tensor.py                          +64   -19
    transformer_engine/pytorch/tensor/utils.py                                     +138  -1
    transformer_engine/pytorch/transformer.py                                      +12   -15
    transformer_engine/pytorch/triton/cross_entropy.py                             +11   -0
    transformer_engine/pytorch/utils.py                                            +217  -14
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py (view file @ f8c2af4c)

@@ -9,14 +9,19 @@ import math
 from typing import Optional, Dict, Any, Tuple

 import torch

 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from ..quantized_tensor import QuantizedTensorBase
 from ...constants import TE_DType_To_Torch
 from ..quantized_tensor import Quantizer
+from ...utils import _empty_tensor


-class Float8BlockwiseQTensorBase:
+class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     """Mixin class that holds data attributes of Float8BlockwiseQTensor.

     Float8BlockwiseQTensor inherits from the PyTorch tensor class and this

@@ -56,6 +61,17 @@ class Float8BlockwiseQTensorBase:
         return instance

+    def clear(self):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
+        for t in (
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ):
+            if t is not None:
+                t.data = _empty_tensor()
+
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
         return {

@@ -73,14 +89,17 @@ class Float8BlockwiseQTensorBase:
     ) -> Tuple[list[Optional[torch.Tensor]], Float8BlockwiseQTensorBase]:
         """
         Prepare the tensor base for saving for backward

         This does not clear the tensors currently, because with PP config
         that clears the weight cache between micro-batches. If the rowwise
         data is not required for backward, this is a possible memory
         pessimization, but is consistent with the other quantized tensor
         classes.
         """
-        tensors = [self._rowwise_data, self._columnwise_data]
+        tensors = [
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ]
         self._rowwise_data = None
         self._columnwise_data = None
+        self._rowwise_scale_inv = None
+        self._columnwise_scale_inv = None
         return tensors, self

     def restore_from_saved(

@@ -89,7 +108,9 @@ class Float8BlockwiseQTensorBase:
         """Restore the tensor base data from the saved tensors list."""
         self._rowwise_data = tensors[0]
         self._columnwise_data = tensors[1]
-        return tensors[2:]
+        self._rowwise_scale_inv = tensors[2]
+        self._columnwise_scale_inv = tensors[3]
+        return tensors[4:]

     def get_data_tensors(self):
         """Get this Tensor's data."""

@@ -232,6 +253,38 @@ class Float8BlockwiseQTensorBase:
         reordered.append(dims[0])
         return torch.Size(reordered)

+    def _create_columnwise(self):
+        """
+        Update columnwise data and columnwise scale inv. Can only be used when using 2D scaling.
+        """
+        assert self._is_2D_scaled, "Cannot create columnwise data when not using 2D scaling."
+        rowwise_data = self._rowwise_data
+        if not rowwise_data.is_contiguous():
+            rowwise_data = rowwise_data.contiguous()
+        self._columnwise_data = tex.fp8_transpose(
+            rowwise_data, self._fp8_dtype, out=self._columnwise_data
+        )
+        if self._columnwise_scale_inv is None:
+            assert self._quantizer is not None, (
+                "._quantizer of Float8BlockwiseQTensor cannot be None because all the blockwise "
+                "quantized tensors are supposed to be generated from the quantizer."
+            )
+            columnwise_scale_inv_shape = self._quantizer.get_scale_shape(rowwise_data.shape, True)
+            self._columnwise_scale_inv = torch.empty(
+                columnwise_scale_inv_shape,
+                dtype=self._rowwise_scale_inv.dtype,
+                device=self._rowwise_scale_inv.device,
+            )
+        assert len(self._rowwise_scale_inv.shape) == 2
+        assert len(self._columnwise_scale_inv.shape) == 2
+        rowwise_scale_inv = self._rowwise_scale_inv
+        columnwise_scale_inv = rowwise_scale_inv.transpose(-2, -1)
+        h = min(self._columnwise_scale_inv.shape[0], columnwise_scale_inv.shape[0])
+        w = min(self._columnwise_scale_inv.shape[1], columnwise_scale_inv.shape[1])
+        self._columnwise_scale_inv[0:h, 0:w].copy_(columnwise_scale_inv[0:h, 0:w])
+
     def __repr__(self):
         if self._rowwise_data is not None:
             data = self.dequantize()

@@ -244,3 +297,55 @@ class Float8BlockwiseQTensorBase:
             f"fp8_dtype={self._fp8_dtype}, "
             f"{descriptor}_scaled_data={data}"
         )
+
+    def update_usage(
+        self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
+    ):
+        """
+        update_usage can be used to clear out one of two possible copies of the data.
+        """
+        if rowwise_usage is None:
+            rowwise_usage = self._rowwise_data is not None
+        if columnwise_usage is None:
+            columnwise_usage = self._columnwise_data is not None
+        assert (
+            columnwise_usage or rowwise_usage
+        ), "Must retain some data either columnwise or rowwise"
+        if columnwise_usage and rowwise_usage:
+            if not self._is_2D_scaled:
+                # For 1D scaling, we cannot create columnwise data/scale_inv from rowwise
+                # data/scale_inv because their scale values are different.
+                assert (
+                    self._rowwise_data is not None
+                    and self._rowwise_scale_inv is not None
+                    and self._columnwise_data is not None
+                    and self._columnwise_scale_inv is not None
+                ), "Cannot update to rowwise and columnwise usage."
+            else:
+                # For 2D scaling, if columnwise data/scale_inv is None, we can create them from
+                # rowwise data/scale_inv.
+                assert (
+                    self._rowwise_data is not None and self._rowwise_scale_inv is not None
+                ), "Cannot update to rowwise and columnwise usage because rowwise data is None."
+                if self._columnwise_data is None or self._columnwise_scale_inv is None:
+                    self._create_columnwise()
+            return
+        if rowwise_usage:
+            assert (
+                self._rowwise_data is not None and self._rowwise_scale_inv is not None
+            ), "Cannot update to rowwise usage."
+            self._columnwise_data = None
+            self._columnwise_scale_inv = None
+            return
+        if columnwise_usage:
+            assert (
+                self._columnwise_data is not None and self._columnwise_scale_inv is not None
+            ), "Cannot update to columnwise usage."
+            self._rowwise_data = None
+            self._rowwise_scale_inv = None
+            return
+        return
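For context: a minimal sketch of how the new `update_usage` behaviour on a 2D-scaled blockwise tensor is intended to be exercised. The variable name `qtensor` is illustrative, not code from this commit; it assumes a tensor produced by a `Float8BlockQuantizer` with 2D scaling.

```python
# Keep only the rowwise copy after the forward GEMM to save memory.
qtensor.update_usage(rowwise_usage=True, columnwise_usage=False)

# Later re-enable columnwise usage; with 2D scaling the columnwise data and
# scale_inv can be regenerated from the rowwise copy via _create_columnwise().
qtensor.update_usage(rowwise_usage=True, columnwise_usage=True)

# Dropping both copies is rejected:
# qtensor.update_usage(rowwise_usage=False, columnwise_usage=False)  # AssertionError
```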
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py (view file @ f8c2af4c)

@@ -12,10 +12,14 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from ..quantized_tensor import QuantizedTensorBase
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ..quantized_tensor import Quantizer
+from ...utils import is_non_tn_fp8_gemm_supported, _empty_tensor


 class _FromFloat8Func(torch.autograd.Function):
     """Cast from FP8 to other dtype"""

@@ -48,7 +52,7 @@ class _FromFloat8Func(torch.autograd.Function):
         return grad, None


-class Float8TensorBase:
+class Float8TensorBase(QuantizedTensorBase):
     """Mixin class that holds data attributes of Float8Tensor.

     Float8Tensor inherits from the PyTorch tensor class and this mixin

@@ -90,6 +94,13 @@ class Float8TensorBase:
         return instance

+    def clear(self):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
+        for t in (self._data, self._transpose, self._scale_inv):
+            if t is not None:
+                t.data = _empty_tensor()
+        self._transpose_invalid = True
+
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
         return {

@@ -100,9 +111,12 @@ class Float8TensorBase:
             "quantizer": self._quantizer,
         }

-    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], Float8TensorBase]:
+    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
         """Prepare the tensor base for saving for backward"""
-        tensors = [self._data, self._transpose]
+        tensors = [self._data, self._transpose, self._scale_inv]
         self._data = None
         self._transpose = None
+        self._scale_inv = None
         return tensors, self

     def restore_from_saved(

@@ -111,7 +125,8 @@ class Float8TensorBase:
         """Restore the tensor base data from the saved tensors list"""
         self._data = tensors[0]
         self._transpose = tensors[1]
-        return tensors[2:]
+        self._scale_inv = tensors[2]
+        return tensors[3:]

     def get_data_tensors(self):
         """Get this Tensor's data."""

@@ -144,3 +159,43 @@ class Float8TensorBase:
             data = data.contiguous()
         self._transpose = tex.fp8_transpose(data, self._fp8_dtype, out=self._transpose)
         self._transpose_invalid = False
+
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+        """
+        Generate or remove FP8 data based on provided usage. For
+        FP8, data cannot be generated even if transpose is available.
+        """
+        has_data = self._data is not None
+        has_data_transpose = self._transpose is not None and not self._transpose_invalid
+        needs_data = has_data
+        needs_data_transpose = has_data_transpose
+        if is_non_tn_fp8_gemm_supported():
+            if rowwise_usage is not None and rowwise_usage:
+                needs_data = True
+            if columnwise_usage is not None and columnwise_usage:
+                needs_data = True
+            needs_data_transpose = False
+        else:
+            if rowwise_usage is not None:
+                needs_data = rowwise_usage
+            if columnwise_usage is not None:
+                needs_data_transpose = columnwise_usage
+
+        # Generate data that is required
+        if needs_data and not has_data:
+            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
+        if needs_data_transpose and not has_data_transpose:
+            if not has_data:
+                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
+            self._create_transpose()
+
+        # Delete data that is not required
+        if not needs_data:
+            self._data = None
+        if not needs_data_transpose:
+            self._transpose = None
+            self._transpose_invalid = True
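For context: a minimal sketch of the save/restore pattern these `prepare_for_saving` / `restore_from_saved` methods support inside a `torch.autograd.Function`. The function below is illustrative only (not a real Transformer Engine op) and assumes `weight_fp8` is a `Float8TensorBase`.

```python
import torch

class _ScaleByTwo(torch.autograd.Function):
    """Illustrative only: shows the save/restore pattern, not a real TE op."""

    @staticmethod
    def forward(ctx, inp, weight_fp8):
        # prepare_for_saving now also hands over _scale_inv, so the base
        # object itself holds no tensors between forward and backward.
        tensors, tensor_base = weight_fp8.prepare_for_saving()
        ctx.save_for_backward(inp, *tensors)
        ctx.weight_base = tensor_base
        return inp * 2

    @staticmethod
    def backward(ctx, grad_out):
        inp, *saved = ctx.saved_tensors
        # restore_from_saved consumes its slice (data, transpose, scale_inv)
        # and returns whatever tensors remain for other consumers.
        leftover = ctx.weight_base.restore_from_saved(list(saved))
        assert leftover == []
        return grad_out * 2, None
```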
transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py (view file @ f8c2af4c)

@@ -11,10 +11,14 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

+from ..quantized_tensor import QuantizedTensorBase
 from ...constants import TE_DType as torch_to_transformer_engine_dtype
 from ..quantized_tensor import Quantizer
+from ...utils import _empty_tensor


 class _FromMXFP8Func(torch.autograd.Function):
     """Cast from MXFP8 to other dtype"""

@@ -43,7 +47,7 @@ class _FromMXFP8Func(torch.autograd.Function):
         return grad, None


-class MXFP8TensorBase:
+class MXFP8TensorBase(QuantizedTensorBase):
     """Mixin class that holds data attributes of MXFP8Tensor.

     MXFP8Tensor inherits from the PyTorch tensor class and this mixin

@@ -81,6 +85,17 @@ class MXFP8TensorBase:
         return instance

+    def clear(self):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
+        for t in (
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ):
+            if t is not None:
+                t.data = _empty_tensor()
+
     def get_metadata(self) -> Dict[str, Any]:
         """Get this tensor's metadata."""
         return {

@@ -94,7 +109,16 @@ class MXFP8TensorBase:
     def prepare_for_saving(
         self,
     ) -> Tuple[list[Optional[torch.Tensor]], MXFP8TensorBase]:
         """Prepare the tensor base for saving for backward"""
-        tensors = [self._rowwise_data, self._columnwise_data]
+        tensors = [
+            self._rowwise_data,
+            self._columnwise_data,
+            self._rowwise_scale_inv,
+            self._columnwise_scale_inv,
+        ]
         self._rowwise_data = None
         self._columnwise_data = None
+        self._rowwise_scale_inv = None
+        self._columnwise_scale_inv = None
         return tensors, self

     def restore_from_saved(

@@ -103,7 +127,9 @@ class MXFP8TensorBase:
         """Restore the tensor base data from the saved tensors list."""
         self._rowwise_data = tensors[0]
         self._columnwise_data = tensors[1]
-        return tensors[2:]
+        self._rowwise_scale_inv = tensors[2]
+        self._columnwise_scale_inv = tensors[3]
+        return tensors[4:]

     def get_data_tensors(self):
         """Get this Tensor's data."""

@@ -129,3 +155,48 @@ class MXFP8TensorBase:
             f"rowwise_scale_inv={self._rowwise_scale_inv}, "
             ")"
         )
+
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+        """
+        For MXFP8, columnwise scaled output is only produced by x2
+        scaling kernels, so this function only disables usages.
+        """
+        # Default usage is based on available data
+        if rowwise_usage is None:
+            rowwise_usage = self._rowwise_data is not None
+        if columnwise_usage is None:
+            columnwise_usage = self._columnwise_data is not None
+
+        # Update row-scaled data
+        if rowwise_usage:
+            if self._rowwise_data is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
+                )
+            if self._rowwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
+                )
+        else:
+            self._rowwise_data = None
+            self._rowwise_scale_inv = None
+
+        # Update column-scaled data
+        if columnwise_usage:
+            if self._columnwise_data is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
+                )
+            if self._columnwise_scale_inv is None:
+                raise RuntimeError(
+                    "Requested column-wise usage, "
+                    "but MXFP8Tensor is missing column-scaled scale-inverses"
+                )
+        else:
+            self._columnwise_data = None
+            self._columnwise_scale_inv = None
transformer_engine/pytorch/tensor/float8_blockwise_tensor.py (view file @ f8c2af4c)

@@ -309,47 +309,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8BlockwiseQTensor.make_like(self)

-    def update_usage(
-        self, rowwise_usage: Optional[bool] = None, columnwise_usage: Optional[bool] = None
-    ):
-        """
-        update_usage can be used to clear out one of two possible copies of the data.
-        """
-        if rowwise_usage is None:
-            rowwise_usage = self._rowwise_data is not None
-        if columnwise_usage is None:
-            columnwise_usage = self._columnwise_data is not None
-        assert (
-            columnwise_usage or rowwise_usage
-        ), "Must retain some data either columnwise or rowwise"
-        if columnwise_usage and rowwise_usage:
-            assert (
-                self._rowwise_data is not None
-                and self._rowwise_scale_inv is not None
-                and self._columnwise_data is not None
-                and self._columnwise_scale_inv is not None
-            ), "Cannot update to rowwise and columnwise usage."
-            return
-        if rowwise_usage:
-            assert (
-                self._rowwise_data is not None and self._rowwise_scale_inv is not None
-            ), "Cannot update to rowwise usage."
-            self._columnwise_data = None
-            self._columnwise_scale_inv = None
-            return
-        if columnwise_usage:
-            assert (
-                self._columnwise_data is not None and self._columnwise_scale_inv is not None
-            ), "Cannot update to columnwise usage."
-            self._rowwise_data = None
-            self._rowwise_scale_inv = None
-            return
-        return
-
     def clone(self) -> Float8BlockwiseQTensor:
         # pylint: disable=missing-function-docstring
         rowwise_data = None

@@ -421,11 +380,6 @@ class Float8BlockwiseQTensor(Float8BlockwiseQTensorBase, QuantizedTensor):
             return self
         raise ValueError("Float8BlockwiseQTensor does not support different memory formats!")

-    def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
-        self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
-
     @classmethod
     def _make_in_reduce_ex(
         cls,

@@ -544,14 +498,64 @@ class _ViewFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring

         # Return input tensor if shape is not provided
-        if ctx is not None:
-            ctx.shape = tensor.shape
+        ctx.shape = tensor.shape
         if shape is None:
             return tensor

-        if list(shape) != list(tensor.shape):
-            raise NotImplementedError("View not implemented.")
-        return tensor
+        # Canonicalize shape
+        if not isinstance(shape, Iterable):
+            shape = [shape]
+        elif len(shape) == 1 and isinstance(shape[0], Iterable):
+            shape = shape[0]
+        if -1 in shape:
+            shape = list(shape)
+            d_inferred = -math.prod(ctx.shape) // math.prod(shape)
+            for i, d in enumerate(shape):
+                if d == -1:
+                    shape[i] = d_inferred
+                    break
+        if tensor._is_2D_scaled:
+            # For the case of 2D scaled tensor, the last 2 dimensions should not change
+            if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]:
+                raise RuntimeError(
+                    "2D scaled Float8BlockwiseQTensor does not support view "
+                    "the last 2 dimensions "
+                    f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})"
+                )
+        else:
+            # For the case of 1D scaled tensor, the last dimension should not change
+            if shape[-1] != ctx.shape[-1]:
+                raise RuntimeError(
+                    "1D scaled Float8BlockwiseQTensor does not support view "
+                    "the last dimension "
+                    f"(attempted to view dims={tuple(tensor.shape)} to {tuple(shape)})"
+                )
+        if list(shape) == list(tensor.shape):
+            return tensor
+
+        # Construct new tensor if shape is provided
+        new_rowwise_data = None
+        new_columnwise_data = None
+        if tensor._rowwise_data is not None:
+            new_rowwise_data = tensor._rowwise_data.view(*shape)
+        if tensor._columnwise_data is not None:
+            columnwise_shape = [shape[-1]] + list(shape[:-1])
+            new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
+        return Float8BlockwiseQTensor(
+            shape=shape,
+            dtype=tensor.dtype,
+            fp8_dtype=tensor._fp8_dtype,
+            rowwise_data=new_rowwise_data,
+            rowwise_scale_inv=tensor._rowwise_scale_inv,
+            columnwise_data=new_columnwise_data,
+            columnwise_scale_inv=tensor._columnwise_scale_inv,
+            quantizer=tensor._quantizer,
+            is_2D_scaled=tensor._is_2D_scaled,
+            requires_grad=tensor.requires_grad,
+        )

     @staticmethod
     def backward(

@@ -561,7 +565,27 @@ class _ViewFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring
         if isinstance(grad, Float8BlockwiseQTensor):
-            raise NotImplementedError("View bwd not implemented")
+            new_data = (
+                grad._rowwise_data.view(*ctx.shape) if grad._rowwise_data is not None else None
+            )
+            if grad._columnwise_data is not None:
+                columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
+                new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
+            else:
+                new_columnwise_data = None
+            dgrad = Float8BlockwiseQTensor(
+                shape=ctx.shape,
+                dtype=grad.dtype,
+                rowwise_data=new_data,
+                rowwise_scale_inv=grad._rowwise_scale_inv,
+                columnwise_data=new_columnwise_data,
+                columnwise_scale_inv=grad._columnwise_scale_inv,
+                fp8_dtype=grad._fp8_dtype,
+                quantizer=grad._quantizer,
+                is_2D_scaled=grad._is_2D_scaled,
+                requires_grad=grad.requires_grad,
+            )
+            return dgrad, None
         return grad.view(ctx.shape), None

@@ -581,8 +605,7 @@ class _ReshapeFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring

         # Return input tensor if shape is not provided
-        if ctx is not None:
-            ctx.shape = tensor.shape
+        ctx.shape = tensor.shape
         if shape is None:
             return tensor

@@ -598,9 +621,47 @@ class _ReshapeFunc(torch.autograd.Function):
             if d == -1:
                 shape[i] = d_inferred
                 break

-        if list(shape) != list(tensor.shape):
-            raise NotImplementedError("Reshape not implemented yet.")
-        return tensor
+        if tensor._is_2D_scaled:
+            # For the case of 2D scaled tensor, the last 2 dimensions should not change
+            if shape[-1] != ctx.shape[-1] or shape[-2] != ctx.shape[-2]:
+                raise RuntimeError(
+                    "2D scaled Float8BlockwiseQTensor does not support reshaping "
+                    "the last 2 dimensions "
+                    f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
+                )
+        else:
+            # For the case of 1D scaled tensor, the last dimension should not change
+            if shape[-1] != ctx.shape[-1]:
+                raise RuntimeError(
+                    "1D scaled Float8BlockwiseQTensor does not support reshaping "
+                    "the last dimension "
+                    f"(attempted to reshape dims={tuple(tensor.shape)} to {tuple(shape)})"
+                )
+        if list(shape) == list(tensor.shape):
+            return tensor
+
+        # Construct new tensor if shape is provided
+        new_rowwise_data = None
+        new_columnwise_data = None
+        if tensor._rowwise_data is not None:
+            new_rowwise_data = tensor._rowwise_data.reshape(*shape)
+        if tensor._columnwise_data is not None:
+            columnwise_shape = [shape[-1]] + list(shape[:-1])
+            new_columnwise_data = tensor._columnwise_data.view(columnwise_shape)
+        return Float8BlockwiseQTensor(
+            shape=shape,
+            dtype=tensor.dtype,
+            fp8_dtype=tensor._fp8_dtype,
+            rowwise_data=new_rowwise_data,
+            rowwise_scale_inv=tensor._rowwise_scale_inv,
+            columnwise_data=new_columnwise_data,
+            columnwise_scale_inv=tensor._columnwise_scale_inv,
+            quantizer=tensor._quantizer,
+            is_2D_scaled=tensor._is_2D_scaled,
+            requires_grad=tensor.requires_grad,
+        )

     @staticmethod
     def backward(

@@ -610,5 +671,24 @@ class _ReshapeFunc(torch.autograd.Function):
         # pylint: disable=missing-function-docstring
         if isinstance(grad, Float8BlockwiseQTensor):
-            raise NotImplementedError("Reshape bwd not implemented yet.")
+            new_rowwise_data = None
+            new_columnwise_data = None
+            if grad._rowwise_data is not None:
+                new_rowwise_data = grad._rowwise_data.view(*ctx.shape)
+            if grad._columnwise_data is not None:
+                columnwise_shape = [ctx.shape[-1]] + list(ctx.shape[:-1])
+                new_columnwise_data = grad._columnwise_data.view(columnwise_shape)
+            dgrad = Float8BlockwiseQTensor(
+                shape=ctx.shape,
+                dtype=grad.dtype,
+                rowwise_data=new_rowwise_data,
+                rowwise_scale_inv=grad._rowwise_scale_inv,
+                columnwise_data=new_columnwise_data,
+                columnwise_scale_inv=grad._columnwise_scale_inv,
+                fp8_dtype=grad._fp8_dtype,
+                quantizer=grad._quantizer,
+                is_2D_scaled=grad._is_2D_scaled,
+                requires_grad=grad.requires_grad,
+            )
+            return dgrad, None
         return grad.view(ctx.shape), None
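For context: the view/reshape constraints above can be illustrated with hypothetical shapes. The tensor `w` below is assumed to be a 1D-scaled `Float8BlockwiseQTensor`; the shapes are made up for illustration.

```python
# Assume w is a 1D-scaled Float8BlockwiseQTensor of shape (4, 32, 128).
w.reshape(128, 128)    # OK: the last dimension (128) is unchanged
w.reshape(-1, 128)     # OK: -1 is inferred as 128
# w.reshape(4, 4096)   # RuntimeError: the last dimension may not change

# For a 2D-scaled tensor the last two dimensions are pinned, so only the
# leading (batch-like) dimensions may be merged or split.
```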
transformer_engine/pytorch/tensor/float8_tensor.py (view file @ f8c2af4c)

@@ -11,7 +11,7 @@ import torch
 import transformer_engine_torch as tex
 from transformer_engine_torch import DType as TE_DType

-from ..utils import canonicalize_process_group, devices_match, non_tn_fp8_gemm_supported
+from ..utils import canonicalize_process_group, devices_match
 from ._internal.float8_tensor_base import Float8TensorBase, _FromFloat8Func
 from .quantized_tensor import QuantizedTensor, Quantizer, _IdentityFunc
 from ..constants import dist_group_type

@@ -422,43 +422,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         # pylint: disable=missing-function-docstring
         return Float8Tensor.make_like(self)

-    def update_usage(
-        self,
-        rowwise_usage: Optional[bool] = None,
-        columnwise_usage: Optional[bool] = None,
-    ):
-        # Figure out what data is available and what is required
-        has_data = self._data is not None
-        has_data_transpose = self._transpose is not None and not self._transpose_invalid
-        needs_data = has_data
-        needs_data_transpose = has_data_transpose
-        if non_tn_fp8_gemm_supported():
-            if rowwise_usage is not None and rowwise_usage:
-                needs_data = True
-            if columnwise_usage is not None and columnwise_usage:
-                needs_data = True
-            needs_data_transpose = False
-        else:
-            if rowwise_usage is not None:
-                needs_data = rowwise_usage
-            if columnwise_usage is not None:
-                needs_data_transpose = columnwise_usage
-
-        # Generate data that is required
-        if needs_data and not has_data:
-            raise RuntimeError("Cannot generate FP8 data, even from FP8 data transpose")
-        if needs_data_transpose and not has_data_transpose:
-            if not has_data:
-                raise RuntimeError("FP8 data is required to generate FP8 data transpose")
-            self._create_transpose()
-
-        # Delete data that is not required
-        if not needs_data:
-            self._data = None
-        if not needs_data_transpose:
-            self._transpose = None
-            self._transpose_invalid = True
-
     def clone(self) -> Float8Tensor:
         # pylint: disable=missing-function-docstring
         assert self._data is not None

@@ -516,12 +479,6 @@ class Float8Tensor(Float8TensorBase, QuantizedTensor):
         del self._transpose  # explicitly deletes the data for safety
         self._transpose = None

-    def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        self._data = torch.Tensor() if self._data is not None else None
-        self._transpose = torch.Tensor() if self._transpose is not None else None
-        self._transpose_invalid = True
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
transformer_engine/pytorch/tensor/mxfp8_tensor.py (view file @ f8c2af4c)

@@ -217,51 +217,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
         # TODO(ksivamani): Fix the detach bug
         return MXFP8Tensor.make_like(self)

-    def update_usage(
-        self,
-        rowwise_usage: Optional[bool] = None,
-        columnwise_usage: Optional[bool] = None,
-    ):
-        """
-        For MXFP8, columnwise scaled output is only produced by x2
-        scaling kernels, so this function only disables usages.
-        """
-        # Default usage is based on available data
-        if rowwise_usage is None:
-            rowwise_usage = self._rowwise_data is not None
-        if columnwise_usage is None:
-            columnwise_usage = self._columnwise_data is not None
-
-        # Update row-scaled data
-        if rowwise_usage:
-            if self._rowwise_data is None:
-                raise RuntimeError(
-                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled FP8 data"
-                )
-            if self._rowwise_scale_inv is None:
-                raise RuntimeError(
-                    "Requested row-wise usage, but MXFP8Tensor is missing row-scaled scale-inverses"
-                )
-        else:
-            self._rowwise_data = None
-            self._rowwise_scale_inv = None
-
-        # Update column-scaled data
-        if columnwise_usage:
-            if self._columnwise_data is None:
-                raise RuntimeError(
-                    "Requested column-wise usage, but MXFP8Tensor is missing column-scaled FP8 data"
-                )
-            if self._columnwise_scale_inv is None:
-                raise RuntimeError(
-                    "Requested column-wise usage, "
-                    "but MXFP8Tensor is missing column-scaled scale-inverses"
-                )
-        else:
-            self._columnwise_data = None
-            self._columnwise_scale_inv = None
-
     def clone(self) -> MXFP8Tensor:
         # pylint: disable=missing-function-docstring
         assert self._rowwise_data is not None

@@ -304,11 +259,6 @@ class MXFP8Tensor(MXFP8TensorBase, QuantizedTensor):
             return self
         raise ValueError("MXFP8Tensor does not support different memory formats!")

-    def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        self._rowwise_data = torch.Tensor() if self._rowwise_data is not None else None
-        self._columnwise_data = torch.Tensor() if self._columnwise_data is not None else None
-
     @classmethod
     def __torch_dispatch__(cls, func, types, args, kwargs=None):
transformer_engine/pytorch/tensor/quantized_tensor.py (view file @ f8c2af4c)

@@ -15,9 +15,66 @@ from torch.utils._pytree import tree_map
 import transformer_engine_torch as tex


+class QuantizedTensorBase:
+    r"""Base class for all *TensorBase classes.
+
+    This class (and its subclasses) is an optimization for when
+    the full QuantizedTensor is not needed (when it is fully
+    contained inside a torch.autograd function and not visible to
+    PyTorch's autograd).
+
+    When creating a new tensor type X one should create both an
+    XTensorBase class inheriting from QuantizedTensorBase and an
+    XTensor inheriting from XTensorBase and QuantizedTensor.
+    XTensorBase should contain all data members needed to
+    implement the functionality of the tensor, while
+    XTensor should only implement the functionality needed
+    to behave like a regular torch.Tensor (like __torch_dispatch__)."""
+
+    def update_usage(
+        self,
+        rowwise_usage: Optional[bool] = None,
+        columnwise_usage: Optional[bool] = None,
+    ):
+        r"""
+        Generate or remove quantized data based on provided usage.
+
+        Parameters
+        ----------
+        rowwise_usage : Optional[bool], default = `None`
+            Whether to create or keep the data needed for using the tensor
+            in rowwise fashion (e.g. as B argument in TN GEMM). Leaving it as `None`
+            preserves the original value in the tensor.
+        columnwise_usage : Optional[bool], default = `None`
+            Whether to create or keep the data needed for using the tensor
+            in columnwise fashion (e.g. as A argument in TN GEMM). Leaving it as
+            `None` preserves the original value in the tensor.
+        """
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement update_usage function"
+        )
+
+    def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorBase]:
+        """Prepare the tensor base for saving for backward"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement prepare_for_saving function"
+        )
+
+    def restore_from_saved(
+        self, tensors: list[Optional[torch.Tensor]]
+    ) -> list[Optional[torch.Tensor]]:
+        """Restore the tensor base data from the saved tensors list"""
+        raise NotImplementedError(
+            f"{self.__class__.__name__} class does not implement restore_from_saved function"
+        )
+
+
 def prepare_for_saving(
-    *tensors,
-) -> Tuple[list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], Optional[Any]]:
+    *tensors: Union[torch.Tensor, QuantizedTensorBase],
+) -> Tuple[
+    list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
+    list[Optional[QuantizedTensorBase]],
+]:
     """Prepare tensors for saving. Needed because save_for_backward accepts only
     torch.Tensor/torch.nn.Parameter types, while we want to be able to save
     the internal TensorBase types too."""

@@ -35,10 +92,13 @@ def prepare_for_saving(

 def restore_from_saved(
-    tensors: list[Optional[Any]],
+    tensors: list[Optional[Union[torch.Tensor, QuantizedTensorBase]]],
     saved_tensors: list[Optional[Union[torch.Tensor, torch.nn.Parameter]]],
     return_saved_tensors: bool = False,
-) -> list[Optional[Any]] | tuple[list[Optional[Any]], list[Optional[torch.Tensor]]]:
+) -> (
+    list[Optional[torch.Tensor | QuantizedTensorBase]]
+    | tuple[list[Optional[torch.Tensor | QuantizedTensorBase]], list[Optional[torch.Tensor]]]
+):
     """Recombine the tensor data and metadata during backward pass."""
     tensor_objects = []
     for tensor in tensors:

@@ -294,21 +354,6 @@ class QuantizedTensor(torch.Tensor):
             f"{self.__class__.__name__} class does not implement detach function"
         )

-    def update_usage(
-        self,
-        rowwise_usage: Optional[bool] = None,
-        columnwise_usage: Optional[bool] = None,
-    ):
-        """Indicate to the tensor how it is going to be used
-
-        This enables optimizations to memory usage in some cases
-        where forward and backward passes use the tensor in
-        different directions.
-        """
-        raise NotImplementedError(
-            f"{self.__class__.__name__} class does not implement update_usage function"
-        )
-
     def clear(self):
         """Deallocate this tensor's memory. Typically not needed and must be used carefully"""
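For context: a sketch of how the module-level `prepare_for_saving` / `restore_from_saved` helpers are meant to be used inside an autograd function that mixes plain tensors and *TensorBase objects. The function bodies and argument names below are illustrative, not code from this commit.

```python
from transformer_engine.pytorch.tensor.quantized_tensor import (
    prepare_for_saving,
    restore_from_saved,
)

def forward(ctx, inp, fp8_weight_base):
    # Flatten everything into plain tensors plus a parallel list of tensor-base
    # objects whose data fields have been handed over.
    saved_tensors, tensor_objects = prepare_for_saving(inp, fp8_weight_base)
    ctx.save_for_backward(*saved_tensors)
    ctx.tensor_objects = tensor_objects
    ...

def backward(ctx, grad):
    # Re-attach the saved tensors to their owners; entries that were plain
    # tensors are returned unchanged.
    inp, fp8_weight_base = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors)
    ...
```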
transformer_engine/pytorch/tensor/utils.py (view file @ f8c2af4c)

@@ -6,13 +6,13 @@
 import torch
 from torch.utils.cpp_extension import IS_HIP_EXTENSION

 import transformer_engine_torch as tex
 from transformer_engine_torch import multi_tensor_scale, multi_tensor_compute_scale_and_scale_inv

 from .quantized_tensor import QuantizedTensor
 from .float8_tensor import Float8Tensor, Float8Quantizer, Float8CurrentScalingQuantizer
 from .mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from .float8_blockwise_tensor import Float8BlockwiseQTensor, Float8BlockQuantizer
 from ..optimizers.multi_tensor_apply import multi_tensor_applier

@@ -33,6 +33,12 @@ def replace_raw_data(tensor: QuantizedTensor, new_raw_data: torch.Tensor):
         new_raw_data.detach().copy_(old_raw_data)
         tensor._data = new_raw_data
         del old_raw_data
+    elif isinstance(tensor, Float8BlockwiseQTensor):
+        old_raw_data = tensor._rowwise_data
+        assert old_raw_data.dtype == new_raw_data.dtype, "The data types of raw data don't match"
+        new_raw_data.detach().copy_(old_raw_data)
+        tensor._rowwise_data = new_raw_data
+        del old_raw_data
     elif isinstance(tensor, MXFP8Tensor):
         raise NotImplementedError("replace_raw_data for MXFP8Tensor is not supported yet")
     else:

@@ -66,6 +72,7 @@ def cast_master_weights_to_fp8(
     delayed_scaling_params = []
     current_scaling_params = []
+    blockwise_scaling_params = []

     if fsdp_shard_model_weights is None:
         use_fsdp_shard_model_weights = False

@@ -107,6 +114,10 @@ def cast_master_weights_to_fp8(
             current_scaling_params.append(
                 (model_weight, master_weight, start_offset, fsdp_shard_model_weight)
             )
+        elif isinstance(quantizer, Float8BlockQuantizer):
+            blockwise_scaling_params.append(
+                (model_weight, master_weight, start_offset, fsdp_shard_model_weight)
+            )
         elif isinstance(quantizer, MXFP8Quantizer):
             raise NotImplementedError(
                 "cast_master_weights_to_fp8 for MXFP8BlockScaling is not supported yet"

@@ -124,6 +135,10 @@ def cast_master_weights_to_fp8(
         _cast_master_weights_to_fp8_current_scaling(
             current_scaling_params, group, use_fsdp_shard_model_weights
         )
+    if len(blockwise_scaling_params) > 0:
+        _cast_master_weights_to_fp8_blockwise_scaling(
+            blockwise_scaling_params, group, use_fsdp_shard_model_weights
+        )


 def _cast_master_weights_to_fp8_delayed_scaling(params, group, use_fsdp_shard_model_weights=False):

@@ -314,3 +329,125 @@ def _cast_master_weights_to_fp8_current_scaling(params, group, use_fsdp_shard_mo
             model_weight.dtype,
         )
         quantizer.update_quantized(master_weight, model_weight_fragment)
+
+
+def _cast_master_weights_to_fp8_blockwise_scaling(params, group, use_fsdp_shard_model_weights=False):
+    r"""Helper function to cast master weights to FP8 primary weights for blockwise scaling.
+
+    Parameters
+    ----------
+    params : List of tuple, each tuple contains a model weight, a master weight, and an offset
+             indicating the starting index of the master weight in the model weight.
+    group : The distributed group to do amax reduction. Typically it's the data parallel
+            group.
+    use_fsdp_shard_model_weights : bool, if True, it means that the model weights are sharded.
+    """
+
+    # Parameter attributes
+    device = params[0][0].device
+    block_len = params[0][0]._get_quantizer().block_len
+    fp8_dtype = params[0][0]._get_quantizer().dtype
+    force_pow_2_scales = params[0][0]._get_quantizer().force_pow_2_scales
+    amax_epsilon = params[0][0]._get_quantizer().amax_epsilon
+
+    # Create a dummy overflow buffer, it's needed by multi_tensor_applier.
+    dummy_overflow_buf = torch.zeros(1, dtype=torch.int, device=device)
+
+    # Get the total number of amax elements in all the model weights.
+    cu_amax_sizes = [0]
+    for model_weight, _, _, _ in params:
+        scale_shape = model_weight._get_quantizer().get_scale_shape(model_weight.shape, False)
+        num_amaxes = scale_shape[0] * scale_shape[1]
+        cu_amax_sizes.append(cu_amax_sizes[-1] + num_amaxes)
+
+    # Create a contiguous buffer to store amaxes temporarily, so we can perform all all-reduce
+    # NCCL kernels at once.
+    packed_amaxes = torch.zeros(cu_amax_sizes[-1], dtype=torch.float32, device=device)
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 1: Iterate through all the non-empty master weights and compute amax of them. Store the
+    #         amaxes in a contiguous buffer. If a block of a master weight is empty, the
+    #         corresponding amax will be set to 0.
+    # ---------------------------------------------------------------------------------------------
+    amaxes, scales, scale_invs = [], [], []
+    for i, (model_weight, master_weight, start_offset, _) in enumerate(params):
+        # Make sure all the model weights have the same numerical options.
+        quantizer = model_weight._get_quantizer()
+        assert block_len == quantizer.block_len
+        assert fp8_dtype == quantizer.dtype
+        assert force_pow_2_scales == quantizer.force_pow_2_scales
+        assert amax_epsilon == quantizer.amax_epsilon
+
+        scale_shape = quantizer.get_scale_shape(model_weight.shape, False)
+        amax = packed_amaxes[cu_amax_sizes[i] : cu_amax_sizes[i + 1]].reshape(scale_shape)
+        scale = torch.empty(scale_shape, dtype=torch.float32, device=device)
+        scale_inv = model_weight._rowwise_scale_inv
+        assert len(scale_shape) == 2
+        assert len(scale_inv.shape) == 2
+        assert scale_inv.shape[0] == scale_shape[0]
+        assert scale_inv.shape[1] == scale_shape[1]
+        amaxes.append(amax)
+        scales.append(scale)
+        scale_invs.append(scale_inv)
+
+        # Compute amax of the master weight and store it in packed_amaxes.
+        if master_weight is not None:
+            assert len(model_weight.shape) == 2
+            h, w = model_weight.shape
+            tex.fp8_block_scaling_compute_partial_amax(
+                master_weight, amax, h, w, start_offset, block_len
+            )
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 2: Perform all-reduce on packed_amaxes to get the global amax.
+    # ---------------------------------------------------------------------------------------------
+    torch.distributed.all_reduce(packed_amaxes, op=torch.distributed.ReduceOp.MAX, group=group)
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 3: Update scales and scale_invs.
+    # ---------------------------------------------------------------------------------------------
+    if fp8_dtype == tex.DType.kFloat8E4M3:
+        max_fp8 = 448.0
+    elif fp8_dtype == tex.DType.kFloat8E5M2:
+        max_fp8 = 57344.0
+    else:
+        raise ValueError(f"Unsupported FP8 dtype: {fp8_dtype}")
+    multi_tensor_applier(
+        multi_tensor_compute_scale_and_scale_inv,
+        dummy_overflow_buf,
+        [amaxes, scales, scale_invs],
+        max_fp8,
+        force_pow_2_scales,
+        amax_epsilon,
+    )
+
+    # ---------------------------------------------------------------------------------------------
+    # Step 4: Cast master weights to FP8.
+    # ---------------------------------------------------------------------------------------------
+    for (model_weight, master_weight, start_offset, model_weight_fragment), scale in zip(
+        params, scales
+    ):
+        # Clear columnwise data for all model weights.
+        # We cannot create columnwise data here because users (like megatron) may want to overlap
+        # the all-gather of model weights and forward process, so the model weight is not updated
+        # at this moment.
+        model_weight.update_usage(rowwise_usage=True, columnwise_usage=False)
+
+        # If master weight is None, it means that the master weight of the current model weight
+        # is in other DP ranks.
+        if master_weight is None:
+            continue
+
+        # Cast master weight to FP8
+        end_offset = start_offset + master_weight.numel()
+        if not use_fsdp_shard_model_weights:
+            model_weight_fragment = model_weight._rowwise_data.reshape(-1)[start_offset:end_offset]
+        assert len(model_weight.shape) == 2
+        h, w = model_weight.shape
+        tex.fp8_block_scaling_partial_cast(
+            master_weight, model_weight_fragment, scale, h, w, start_offset, block_len, fp8_dtype
+        )
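For context: Step 3 above is performed by the fused `multi_tensor_compute_scale_and_scale_inv` kernel. As a rough, eager-mode sketch of the per-block arithmetic (an assumption based on the surrounding code, not the kernel's exact implementation), the scale is derived from the reduced amax and the FP8 representable maximum:

```python
import torch

def compute_scale_and_scale_inv(amax, max_fp8=448.0, force_pow_2_scales=False, amax_epsilon=0.0):
    # max_fp8 = 448.0 corresponds to E4M3 weights, 57344.0 to E5M2 (as in Step 3).
    amax = torch.clamp(amax, min=amax_epsilon)
    scale = max_fp8 / amax                  # applied to the master weight before rounding to FP8
    scale[amax == 0] = 1.0                  # empty blocks keep a neutral scale
    if force_pow_2_scales:
        scale = torch.exp2(torch.floor(torch.log2(scale)))
    return scale, 1.0 / scale               # scale_inv is what GEMMs use to dequantize
```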
transformer_engine/pytorch/transformer.py (view file @ f8c2af4c)

@@ -10,13 +10,11 @@ from typing import Callable, List, Optional, Tuple, Union
 import torch

+from transformer_engine.pytorch import torch_version
 from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
-from transformer_engine.pytorch.attention import (
-    MultiheadAttention,
-)
-from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
-from transformer_engine.pytorch.dot_product_attention.utils import check_set_window_size
+from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
+from transformer_engine.pytorch.attention.inference import InferenceParams
 from transformer_engine.pytorch.jit import (
     set_jit_fusion_options,
     warmup_jit_bias_dropout_add_all_dtypes,

@@ -27,6 +25,7 @@ from transformer_engine.pytorch.jit import (
 from transformer_engine.pytorch.utils import (
     cast_if_needed,
     get_default_init_method,
+    torch_get_autocast_gpu_dtype,
 )
 from transformer_engine.pytorch.constants import (
     AttnMaskTypes,

@@ -169,6 +168,8 @@ class TransformerLayer(torch.nn.Module):
                       interpretation is that the individual `q`, `k`, and `v` weights for each
                       attention head are interleaved. This parameter is set to `False` when
                       using :attr:`fuse_qkv_params=False`.
+    rotary_pos_interleaved : bool, default = `False`
+                            whether to use interleaved rotary position embeddings.
     bias : bool, default = `True`
           if set to `False`, the transformer layer will not learn any additive biases.
     activation : str, default = 'gelu'

@@ -268,6 +269,7 @@ class TransformerLayer(torch.nn.Module):
         drop_path_rate: float = 0.0,
         set_parallel_mode: bool = False,
         fuse_qkv_params: bool = False,
+        rotary_pos_interleaved: bool = False,
         zero_centered_gamma: bool = False,
         qkv_weight_interleaved: bool = True,
         ub_tp_comm_overlap: bool = False,

@@ -286,11 +288,9 @@ class TransformerLayer(torch.nn.Module):
         super().__init__()

         self.self_attn_mask_type = self_attn_mask_type
-        self.window_size = check_set_window_size(self_attn_mask_type, window_size)
+        self.window_size = window_size
         self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
-        self.enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
+        self.enc_dec_window_size = enc_dec_window_size
         params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
         ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad

@@ -366,6 +366,7 @@ class TransformerLayer(torch.nn.Module):
             "fuse_qkv_params": fuse_qkv_params,
             "zero_centered_gamma": zero_centered_gamma,
             "qkv_weight_interleaved": qkv_weight_interleaved,
+            "rotary_pos_interleaved": rotary_pos_interleaved,
             "ub_bulk_wgrad": ub_bulk_wgrad,
             "ub_bulk_dgrad": ub_bulk_dgrad,
             "ub_overlap_ag": ub_overlap_ag,

@@ -440,9 +441,7 @@ class TransformerLayer(torch.nn.Module):
         self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None

         # Set bias+dropout+add fusion grad_enable execution handler.
-        TORCH_MAJOR = int(torch.__version__.split(".")[0])
-        TORCH_MINOR = int(torch.__version__.split(".")[1])
-        use_nvfuser = TORCH_MAJOR > 1 or (TORCH_MAJOR == 1 and TORCH_MINOR >= 10)
+        use_nvfuser = torch_version() >= (1, 10, 0) and torch_version() < (2, 2, 0)
         self.bias_dropout_add_exec_handler = nullcontext if use_nvfuser else torch.enable_grad

         if self.bias_dropout_fusion:

@@ -657,12 +656,10 @@ class TransformerLayer(torch.nn.Module):
             self_attn_mask_type = self.self_attn_mask_type
         if window_size is None:
             window_size = self.window_size
-        window_size = check_set_window_size(self_attn_mask_type, window_size)
         if enc_dec_attn_mask_type is None:
             enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
         if enc_dec_window_size is None:
             enc_dec_window_size = self.enc_dec_window_size
-        enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)

         assert (
             self_attn_mask_type in AttnMaskTypes

@@ -694,7 +691,7 @@ class TransformerLayer(torch.nn.Module):
         # For AMP
         if torch.is_autocast_enabled():
-            hidden_states = cast_if_needed(hidden_states, torch.get_autocast_gpu_dtype())
+            hidden_states = cast_if_needed(hidden_states, torch_get_autocast_gpu_dtype())

         # Self attention.
         self_attention_outputs = self.self_attention(
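For context: a minimal construction sketch showing the new `rotary_pos_interleaved` flag on `TransformerLayer`. The sizes below are illustrative assumptions; only the flag itself comes from this diff.

```python
from transformer_engine.pytorch import TransformerLayer

layer = TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=16384,
    num_attention_heads=32,
    rotary_pos_interleaved=True,  # new: use interleaved rotary position embeddings
)
```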
transformer_engine/pytorch/triton/cross_entropy.py (view file @ f8c2af4c)

@@ -95,6 +95,7 @@ def cross_entropy_kernel(
     m_d_X_y_stride,
     rank,
     world_size,
+    ignore_idx,
     n_cols,
     n_non_ignore,
     label_smoothing: tl.constexpr,

@@ -114,6 +115,7 @@ def cross_entropy_kernel(
         m_d_X_y_stride: The stride of m/d/X_y tensor.
         rank (int): The rank of this device in the TP group.
         world_size (int): The size of world involved in this distributed loss calculation.
+        ignore_idx (int): Tokens to be ignored for loss and gradient calculation.
         n_cols (int): The number of columns in the input tensor.
         n_non_ignore (int): The number of non-ignored elements in the batch.
         label_smoothing (float): The amount of smoothing when computing the loss, where 0.0 means no smoothing.

@@ -129,6 +131,13 @@ def cross_entropy_kernel(
     Y_ptr += program_id * Y_stride
     y = tl.load(Y_ptr)

+    if y == ignore_idx:
+        # set all X_ptr as 0
+        for i in range(0, n_cols, BLOCK_SIZE):
+            X_offsets = i + tl.arange(0, BLOCK_SIZE)
+            tl.store(X_ptr + X_offsets, 0.0, mask=X_offsets < n_cols)
+        return
+
     loss_ptr += program_id * loss_stride
     m_d_X_y_ptr += program_id * 3 * m_d_X_y_stride

@@ -248,6 +257,7 @@ def cross_entropy_forward(
     label_smoothing: float,
     reduce_loss: bool,
     dist_process_group: Union[dist.ProcessGroup, None],
+    ignore_idx: int,
 ):
     """Forward implementation of Cross Entropy kernel"""

@@ -306,6 +316,7 @@ def cross_entropy_forward(
         m_d_X_y_stride=m_d_X_y_gathered.stride(-1),
         rank=rank,
         world_size=world_size,
+        ignore_idx=ignore_idx,
         n_cols=V,
         n_non_ignore=n_rows,
         label_smoothing=label_smoothing,
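For context: the new `ignore_idx` branch mirrors the standard ignore-index semantics of cross entropy, where ignored tokens contribute zero loss and zero gradient. A small eager-mode reference (plain PyTorch, not the Triton kernel):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10, requires_grad=True)
labels = torch.tensor([1, 3, -100, 7])          # -100 used here as the ignore index
loss = F.cross_entropy(logits, labels, ignore_index=-100, reduction="sum")
loss.backward()
assert torch.all(logits.grad[2] == 0)           # the row for the ignored token gets zero grad
```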
transformer_engine/pytorch/utils.py (view file @ f8c2af4c)

@@ -7,13 +7,13 @@ from __future__ import annotations
 import functools
 import math
 import os
-from typing import Any, Callable, List, Optional, Tuple
+from typing import Any, Callable, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import transformer_engine.pytorch.cpp_extensions as ext
 from . import torch_version
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
 from .tensor.quantized_tensor import QuantizedTensor
 from torch.utils.cpp_extension import IS_HIP_EXTENSION


 def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:

@@ -24,6 +24,12 @@ def requires_grad(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     return False


+@functools.lru_cache(maxsize=None)
+def _empty_tensor() -> torch.Tensor:
+    """Get tensor with no entries and no data"""
+    return torch.Tensor().cuda()
+
+
 def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """
     Trick to deallocate tensor memory when delete operation does not

@@ -33,17 +39,22 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None:
     """
     for t in tensors:
         if t is not None:
-            if isinstance(t, QuantizedTensor):
+            if hasattr(t, "clear"):
                 t.clear()
             else:
-                t.data = torch.Tensor()
+                t.data = _empty_tensor()
             del t


+@functools.lru_cache
+def _get_device_compute_capability(device: torch.device) -> Tuple[int, int]:
+    props = torch.cuda.get_device_properties(device)
+    return (props.major, props.minor)
+
+
 def get_device_compute_capability() -> Tuple[int, int]:
     """CUDA compute capability of current GPU"""
-    props = torch.cuda.get_device_properties(torch.cuda.current_device())
-    return (props.major, props.minor)
+    return _get_device_compute_capability(torch.cuda.current_device())


 def attention_mask_func(

@@ -155,6 +166,184 @@ def split_tensor_along_dim(
     return tensor_list


+# @klakhani TODO: Consider combining with split_tensor_along_dim() and no_op_cat() and SplitAlongDim
+def combine_tensors(
+    tensors: List[torch.Tensor],
+    dim: int,
+) -> torch.Tensor:
+    """Combine tensors along a particular dimension"""
+
+    num_tensors = len(tensors)
+    new_shape = list(tensors[0].shape)
+    new_shape.insert(dim, num_tensors)
+
+    from transformer_engine.pytorch.float8_tensor import Float8Tensor
+
+    if isinstance(tensors[0], Float8Tensor):
+        new_stride = list(tensors[0]._data.stride())
+        new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
+        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype)
+        combined_tensor.set_(
+            tensors[0]._data.untyped_storage(),
+            tensors[0]._data.storage_offset(),
+            new_shape,
+            new_stride,
+        )
+        combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape)
+    else:
+        new_stride = list(tensors[0].stride())
+        new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
+        combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype)
+        combined_tensor.set_(
+            tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride
+        )
+
+    return combined_tensor
+
+
+class SplitAlongDim(torch.autograd.Function):
+    """
+    Split tensor along given dimension
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        mixed_x_layer: torch.Tensor,
+        split_dim: int,
+        split_size_or_sections: Union[int, List[int], Tuple[int]],
+        squeeze=False,
+    ) -> Tuple[torch.Tensor, ...]:
+        # pylint: disable=missing-function-docstring
+        ctx.split_dim = split_dim
+        ctx.split_size_or_sections = split_size_or_sections
+
+        from transformer_engine.pytorch.float8_tensor import Float8Tensor
+        from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
+
+        if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
+            mixed_x_layer, Float8Tensor
+        ):
+            return tuple(
+                Float8TensorBase(
+                    fp8_scale_inv=mixed_x_layer._scale_inv,
+                    fp8_dtype=mixed_x_layer._fp8_dtype,
+                    data=x.squeeze(split_dim) if squeeze else x,
+                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
+                    quantizer=mixed_x_layer._quantizer,
+                )
+                for x in torch.split(
+                    mixed_x_layer._data,
+                    split_size_or_sections=split_size_or_sections,
+                    dim=split_dim,
+                )
+            )
+        if isinstance(mixed_x_layer, Float8Tensor):
+            return tuple(
+                Float8Tensor.make_like(
+                    mixed_x_layer,
+                    data=x.squeeze(split_dim) if squeeze else x,
+                    shape=x.squeeze(split_dim).shape if squeeze else x.shape,
+                )
+                for x in torch.split(
+                    mixed_x_layer._data,
+                    split_size_or_sections=split_size_or_sections,
+                    dim=split_dim,
+                )
+            )
+
+        out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
+        if squeeze:
+            out_list = [x.squeeze(split_dim) for x in out_list]
+        return out_list
+
+    @staticmethod
+    def backward(ctx, *grad_outputs):
+        # pylint: disable=missing-function-docstring
+        assert len(grad_outputs) > 0, "No gradients received for backprop!"
+
+        if isinstance(ctx.split_size_or_sections, (list, tuple)):
+            split_sizes = ctx.split_size_or_sections
+            assert len(grad_outputs) == len(
+                split_sizes
+            ), "Unequal number of gradients vs split sections for backprop!"
+        if isinstance(ctx.split_size_or_sections, int):
+            split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
+        dims = len(grad_outputs[0].shape)
+        split_dim = (ctx.split_dim + dims) % dims
+
+        from transformer_engine.pytorch.float8_tensor import Float8Tensor
+
+        if isinstance(grad_outputs[0], Float8Tensor):
+            noop_ok = True
+            strides = grad_outputs[0].stride()
+            data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
+            shape = list(grad_outputs[0].shape)
+            for i, tensor in enumerate(grad_outputs):
+                shape_i = shape
+                shape_i[split_dim] = split_sizes[i]
+                offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
+                if (
+                    tensor.stride() != strides
+                    or list(tensor.shape) != shape_i
+                    or tensor._data.untyped_storage().data_ptr() != data_ptr
+                    or tensor.storage_offset() != offset_size
+                ):
+                    noop_ok = False
+                    break
+            if noop_ok:
+                ret = torch.Tensor().to(
+                    device=grad_outputs[0].device, dtype=grad_outputs[0]._data.dtype
+                )
+                new_shape = list(shape)
+                new_shape[split_dim] = sum(split_sizes)
+                ret.set_(
+                    grad_outputs[0]._data.untyped_storage(),
+                    grad_outputs[0]._data.storage_offset(),
+                    new_shape,
+                    strides,
+                )
+                return (
+                    Float8Tensor.make_like(grad_outputs[0], data=ret, shape=ret.shape),
+                    None,
+                    None,
+                )
+
+            grad_outputs_data = [x._data for x in grad_outputs]
+            data = torch.cat(grad_outputs_data, dim=split_dim)
+            return (
+                Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape),
+                None,
+                None,
+                None,
+            )
+
+        noop_ok = True
+        strides = grad_outputs[0].stride()
+        data_ptr = grad_outputs[0].untyped_storage().data_ptr()
+        shape = list(grad_outputs[0].shape)
+        for i, tensor in enumerate(grad_outputs):
+            shape_i = shape
+            shape_i[split_dim] = split_sizes[i]
+            offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
+            if (
+                tensor.stride() != strides
+                or list(tensor.shape) != shape_i
+                or tensor.untyped_storage().data_ptr() != data_ptr
+                or tensor.storage_offset() != offset_size
+            ):
+                noop_ok = False
+                break
+        if noop_ok:
+            ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
+            new_shape = list(shape)
+            new_shape[split_dim] = sum(split_sizes)
+            ret.set_(
+                grad_outputs[0].untyped_storage(),
+                grad_outputs[0].storage_offset(),
+                new_shape,
+                strides,
+            )
+            return ret, None, None
+
+        return torch.cat(grad_outputs, dim=split_dim), None, None
+
+
 def validate_ctx_manager(ctx: Callable) -> None:
     """Checks if passed in object can be used as a context manager."""
     try:

@@ -237,10 +426,10 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
     """Assert that tensor or tensors dimensions are supported for FP8 TN GEMM."""
     for tensor in tensors:
-        assert tensor.dim() == 2 and tensor.size(0) % 8 == 0 and tensor.size(1) % 16 == 0, (
-            "FP8 execution requires 2D input matrices with "
-            "height divisible by 8 and width divisible by 16, "
-            f"but got tensor with dims={list(tensor.size())}"
+        assert math.prod(tensor.shape[:-1]) % 8 == 0 and tensor.shape[-1] % 16 == 0, (
+            "FP8 execution requires the product of all dimensions except the last to be divisible "
+            "by 8 and the last dimension to be divisible by 16, but got tensor with "
+            f"dims={list(tensor.size())}"
         )
     if IS_HIP_EXTENSION:

@@ -273,11 +462,12 @@ def is_bf16_compatible() -> None:
     return torch.cuda.get_device_capability()[0] >= 8


-def non_tn_fp8_gemm_supported() -> bool:
+def is_non_tn_fp8_gemm_supported() -> bool:
     """Checks whether the device supports non-TN layouts for FP8 GEMMs."""
-    return torch.cuda.get_device_capability() >= (10, 0)
+    device_capability = torch.cuda.get_device_capability()
+    return (10, 0) <= device_capability < (12, 0) or device_capability >= (13, 0)


 @functools.lru_cache(maxsize=None)

@@ -438,3 +628,16 @@ def canonicalize_process_group(
     if group is None:
         return torch.distributed.distributed_c10d._get_default_group()
     return group
+
+
+def torch_get_autocast_gpu_dtype() -> torch.dtype:
+    """Get PyTorch autocast GPU dtype."""
+    if torch_version() >= (2, 4, 0):
+        return torch.get_autocast_dtype("cuda")
+    return torch.get_autocast_gpu_dtype()
+
+
+if torch_version() >= (2, 4, 0):
+    gpu_autocast_ctx = functools.partial(torch.amp.autocast, device_type="cuda")
+else:
+    gpu_autocast_ctx = torch.cuda.amp.autocast
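For context: a sketch of how the new `SplitAlongDim` autograd function is meant to be called, for example when splitting a fused QKV projection. The shapes below are made up for illustration; only the argument order comes from the diff.

```python
import torch

qkv = torch.randn(8, 1024, 3, 512, requires_grad=True)   # [sq, b, 3, hidden] (illustrative)
q, k, v = SplitAlongDim.apply(qkv, 2, [1, 1, 1], True)    # split along dim 2, then squeeze it
assert q.shape == (8, 1024, 512)

# For plain torch.Tensor inputs the backward pass concatenates the gradients,
# or returns a no-op view when they are contiguous slices of a single buffer.
```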