OpenDAS / TransformerEngine

Commit 27ddce40, authored Oct 11, 2025 by wenjh

    Merge branch 'nv_main'

Parents: d262ef4c, 5b3092a0

Showing 8 changed files with 138 additions and 14 deletions:

- transformer_engine/pytorch/ops/fuser.py (+2, -0)
- transformer_engine/pytorch/ops/linear.py (+2, -1)
- transformer_engine/pytorch/setup.py (+101, -4)
- transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py (+5, -0)
- transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py (+7, -2)
- transformer_engine/pytorch/tensor/float8_tensor.py (+16, -6)
- transformer_engine/pytorch/transformer.py (+2, -1)
- transformer_engine/pytorch/triton/cross_entropy.py (+3, -0)
transformer_engine/pytorch/ops/fuser.py (+2, -0)

```diff
@@ -19,6 +19,7 @@ from transformer_engine.pytorch.ops.op import (
 )
 from transformer_engine.pytorch.ops.fused import (
     fuse_backward_activation_bias,
+    fuse_backward_add_rmsnorm,
     fuse_backward_linear_add,
     fuse_backward_linear_scale,
     fuse_forward_linear_bias_activation,
@@ -371,6 +372,7 @@ class OperationFuser:
         ops = fuse_backward_linear_add(ops)
         ops = fuse_backward_linear_scale(ops)
         ops = fuse_backward_activation_bias(ops, recipe)
+        ops = fuse_backward_add_rmsnorm(ops)
         return ops

     def maybe_fuse_ops(
```
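Each of these passes takes the current operation list and returns a new one, so `fuse_backward_add_rmsnorm` joins the chain with a single extra call. A minimal sketch of that `ops -> ops` contract, with hypothetical op classes rather than TE's real implementation:

```python
# Illustrative only: the op names and the fused placeholder are assumptions,
# not TransformerEngine's actual fusion code.
def fuse_backward_add_rmsnorm_sketch(ops: list) -> list:
    """Replace adjacent (Add, RMSNorm) pairs with a single fused op."""
    fused, i = [], 0
    while i < len(ops):
        nxt = ops[i + 1] if i + 1 < len(ops) else None
        if type(ops[i]).__name__ == "Add" and type(nxt).__name__ == "RMSNorm":
            fused.append(("FusedAddRMSNorm", ops[i], nxt))  # placeholder fused op
            i += 2  # consume both matched ops
        else:
            fused.append(ops[i])
            i += 1
    return fused
```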
transformer_engine/pytorch/ops/linear.py (+2, -1)

```diff
@@ -54,7 +54,8 @@ class Linear(FusedOperation):
         weight's `main_grad` attribute instead of relying on PyTorch
         autograd. The weight's `main_grad` must be set externally and
         there is no guarantee that `grad` will be set or be
-        meaningful.
+        meaningful. This is primarily intended to integrate with
+        Megatron-LM.

     """
```
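For context, the Megatron-LM-style contract described here means the caller allocates `main_grad` buffers before the backward pass and TE accumulates weight gradients into them directly. A hedged sketch (the `accumulate_into_main_grad` flag name is assumed from the docstring above; requires a CUDA device):

```python
import torch
import transformer_engine.pytorch as te

# Assumed flag name, taken from the docstring context in the diff above.
linear = te.ops.Linear(1024, 1024, device="cuda", accumulate_into_main_grad=True)
for p in linear.parameters():
    # main_grad must be set externally, e.g. by Megatron-LM's DDP wrapper.
    p.main_grad = torch.zeros_like(p, dtype=torch.float32)

out = linear(torch.randn(8, 1024, device="cuda"))
out.sum().backward()  # wgrad accumulates into p.main_grad; p.grad may stay unset
```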
transformer_engine/pytorch/setup.py (+101, -4)

```diff
@@ -10,14 +10,30 @@ import sys
 import os
 import shutil
 from pathlib import Path
+import platform
+import urllib

 import setuptools
+from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
+from packaging.version import parse

 try:
     import torch
     from torch.utils.cpp_extension import BuildExtension
 except ImportError as e:
     raise RuntimeError("This package needs Torch to build.") from e

+FORCE_BUILD = os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE"
+FORCE_CXX11_ABI = os.getenv("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE"
+SKIP_CUDA_BUILD = os.getenv("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
+
+PACKAGE_NAME = "transformer_engine_torch"
+BASE_WHEEL_URL = (
+    "https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}"
+)
+
+# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
+# torch._C._GLIBCXX_USE_CXX11_ABI
+# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
+if FORCE_CXX11_ABI:
+    torch._C._GLIBCXX_USE_CXX11_ABI = True

 current_file_path = Path(__file__).parent.resolve()
 build_tools_dir = current_file_path.parent.parent / "build_tools"
```
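Note that these `NVTE_PYTORCH_*` switches compare against the exact string `"TRUE"`, so common truthy spellings are silently ignored. A quick illustration:

```python
import os

os.environ["NVTE_PYTORCH_FORCE_BUILD"] = "true"   # lowercase: flag stays off
assert os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") != "TRUE"

os.environ["NVTE_PYTORCH_FORCE_BUILD"] = "TRUE"   # exact match: flag is on
assert os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE"
```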
```diff
@@ -31,13 +47,94 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
 from build_tools.build_ext import get_build_ext
 from build_tools.utils import copy_common_headers
 from build_tools.te_version import te_version
-from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements
+from build_tools.pytorch import (
+    setup_pytorch_extension,
+    install_requirements,
+    test_requirements,
+)

 os.environ["NVTE_PROJECT_BUILDING"] = "1"
 CMakeBuildExtension = get_build_ext(BuildExtension, True)


+def get_platform():
+    """
+    Returns the platform name as used in wheel filenames.
+    """
+    if sys.platform.startswith("linux"):
+        return f"linux_{platform.uname().machine}"
+    if sys.platform == "darwin":
+        mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
+        return f"macosx_{mac_version}_x86_64"
+    if sys.platform == "win32":
+        return "win_amd64"
+    raise ValueError(f"Unsupported platform: {sys.platform}")
+
+
+def get_wheel_url():
+    """Construct the wheel URL for the current platform."""
+    torch_version_raw = parse(torch.__version__)
+    python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
+    platform_name = get_platform()
+    nvte_version = te_version()
+    torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
+    cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
+
+    # Determine the version numbers that will be used to determine the correct wheel
+    # We're using the CUDA version used to build torch, not the one currently installed
+    # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
+    torch_cuda_version = parse(torch.version.cuda)
+    # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for
+    # CUDA 12.3 to save CI time. Minor versions should be compatible.
+    torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
+    # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
+    cuda_version = f"{torch_cuda_version.major}"
+
+    # Determine wheel URL based on CUDA version, torch version, python version and OS
+    wheel_filename = (
+        f"{PACKAGE_NAME}-{nvte_version}+cu{cuda_version}torch{torch_version}"
+        f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
+    )
+    wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{nvte_version}", wheel_name=wheel_filename)
+    return wheel_url, wheel_filename
+
+
+class CachedWheelsCommand(_bdist_wheel):
+    """
+    The CachedWheelsCommand plugs into the default bdist_wheel, which is run by pip when it
+    cannot find an existing wheel (which is currently the case for all grouped gemm installs).
+    We use the environment parameters to detect whether there is already a pre-built version
+    of a compatible wheel available and short-circuit the standard full build pipeline.
+    """
+
+    def run(self):
+        if FORCE_BUILD:
+            super().run()
+
+        wheel_url, wheel_filename = get_wheel_url()
+        print("Guessing wheel URL: ", wheel_url)
+        try:
+            urllib.request.urlretrieve(wheel_url, wheel_filename)
+
+            # Make the archive
+            # Lifted from the root wheel processing command
+            # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
+            if not os.path.exists(self.dist_dir):
+                os.makedirs(self.dist_dir)
+
+            impl_tag, abi_tag, plat_tag = self.get_tag()
+            archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
+
+            wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
+            print("Raw wheel path", wheel_path)
+            os.rename(wheel_filename, wheel_path)
+        except (urllib.error.HTTPError, urllib.error.URLError):
+            print("Precompiled wheel not found. Building from source...")
+            # If the wheel could not be downloaded, build from source
+            super().run()
+
+
 if __name__ == "__main__":
     # Extensions
     common_headers_dir = "common_headers"
```
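To make the naming scheme concrete, here is the filename `get_wheel_url()` would assemble under one assumed configuration (all versions below are placeholders for illustration, not a published release):

```python
# Assumed: TE 2.1.0, torch 2.4 built against CUDA 12.x, CPython 3.10,
# CXX11 ABI enabled, x86_64 Linux.
PACKAGE_NAME = "transformer_engine_torch"
nvte_version, cuda_version, torch_version = "2.1.0", "12", "2.4"
python_version, cxx11_abi, platform_name = "cp310", "TRUE", "linux_x86_64"

wheel_filename = (
    f"{PACKAGE_NAME}-{nvte_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# transformer_engine_torch-2.1.0+cu12torch2.4cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
```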
```diff
@@ -50,11 +147,11 @@ if __name__ == "__main__":
     # Configure package
     setuptools.setup(
-        name="transformer_engine_torch",
+        name=PACKAGE_NAME,
         version=te_version(),
         description="Transformer acceleration library - Torch Lib",
         ext_modules=ext_modules,
-        cmdclass={"build_ext": CMakeBuildExtension},
+        cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
         install_requires=install_requirements(),
         tests_require=test_requirements(),
     )
```
transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py (+5, -0)

```diff
@@ -350,9 +350,14 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
     def _transpose_columnwise_data(self):
         """Plainly transpose the columnwise data and scale inv."""
         if self._columnwise_data is not None:
+            # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically
+            # deallocated by GC. Manually deallocating is a temporary hack.
+            _old_data = self._columnwise_data
             self._columnwise_data = tex.fp8_transpose(
                 self._columnwise_data, self._fp8_dtype, out=None
             )
+            _old_data.data = _empty_tensor()
+            del _old_data

     def __repr__(self):
         if self._rowwise_data is not None:
```
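The pattern above releases GPU memory by rebinding the old tensor's `.data` to an empty tensor, which drops the reference to the CUDA storage immediately instead of waiting for garbage collection. A standalone sketch of the same trick (plain PyTorch, not TE's `_empty_tensor` helper):

```python
import torch

def release_storage(t: torch.Tensor) -> None:
    # Rebinding .data frees the old storage as soon as nothing else holds it,
    # without waiting for the Python GC to collect the tensor object itself.
    t.data = torch.empty(0, device=t.device)

device = "cuda" if torch.cuda.is_available() else "cpu"
x = torch.randn(1024, 1024, device=device)
release_storage(x)
assert x.numel() == 0  # the 4 MiB buffer is gone; x is still a live object
```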
transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py (+7, -2)

```diff
@@ -95,8 +95,13 @@ class Float8TensorBase(QuantizedTensorBase):
         return instance

     def clear(self):
-        """Deallocate this tensor's memory. Typically not needed and must be used carefully."""
-        for t in (self._data, self._transpose, self._scale_inv):
+        """Deallocate this tensor's memory. Typically not needed and must be used carefully.
+
+        The scale-inv tensor is not deallocated because it is often shared
+        between multiple FP8 tensors.
+
+        """
+        for t in (self._data, self._transpose):
             if t is not None:
                 t.data = _empty_tensor()
         self._transpose_invalid = True
```
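The reason `_scale_inv` is now skipped: several FP8 tensors can alias one scale-inv tensor, so emptying it from one tensor's `clear()` would corrupt the others. A hypothetical illustration of the aliasing hazard:

```python
import torch

shared_scale_inv = torch.ones(1)
a_scale_inv = shared_scale_inv  # e.g. FP8 tensor A's _scale_inv (assumed)
b_scale_inv = shared_scale_inv  # e.g. FP8 tensor B's _scale_inv (assumed)

a_scale_inv.data = torch.empty(0)  # clearing via A...
assert b_scale_inv.numel() == 0    # ...silently empties B's scale too
```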
transformer_engine/pytorch/tensor/float8_tensor.py (+16, -6)

```diff
@@ -178,7 +178,7 @@ class Float8Quantizer(Quantizer):
     def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
         """Function using primitives with ONNX defined translations."""
-        out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
+        out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
         out = out.to(tensor.dtype)
         return out
@@ -351,15 +351,25 @@ class Float8CurrentScalingQuantizer(Quantizer):
     def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Function using primitives with ONNX defined translations."""
-        raise NotImplementedError(
-            "Float8CurrentScalingQuantizer does not support ONNX quantization yet."
-        )
+        if tensor.dtype != torch.float32:
+            tensor = tensor.to(torch.float32)
+        data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor)
+        return Float8Tensor(
+            shape=data.shape,
+            dtype=torch.float32,
+            data=data,
+            fp8_scale_inv=scale_inv,
+            fp8_dtype=self.dtype,
+            requires_grad=False,
+            data_transpose=None,
+            quantizer=self,
+        )

     def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
         """Function using primitives with ONNX defined translations."""
-        raise NotImplementedError(
-            "Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
-        )
+        out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
+        out = out.to(tensor.dtype)
+        return out

     def _canonicalized_amax_reduction_group(self) -> dist_group_type:
         """Get process group for amax reduction"""
```
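For intuition about what `fp8_cs_quantize` computes, here is a pure-PyTorch sketch of current-scaling FP8 quantization and the matching dequantize. This is a simplified model of the op, not the `torch.ops.tex` kernel, and it assumes a torch build with `float8_e4m3fn`:

```python
import torch

def fp8_cs_quantize_sketch(x: torch.Tensor):
    # "Current scaling": the scale comes from this tensor's own amax, mapping
    # values into the representable E4M3 range (max normal value is 448).
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = fp8_max / amax
    data = (x.float() * scale).to(torch.float8_e4m3fn)
    return data, 1.0 / scale  # data plus scale_inv, as in the diff above

def fp8_dequantize_sketch(data, scale_inv, dtype=torch.float32):
    return (data.float() * scale_inv).to(dtype)

x = torch.randn(4, 4)
data, scale_inv = fp8_cs_quantize_sketch(x)
x_round_trip = fp8_dequantize_sketch(data, scale_inv)
assert torch.allclose(x, x_round_trip, rtol=0.07, atol=1e-3)  # FP8-level error
```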
transformer_engine/pytorch/transformer.py (+2, -1)

```diff
@@ -175,7 +175,8 @@ class TransformerLayer(torch.nn.Module):
         if set to `False`, the transformer layer will not learn any additive biases.
     activation : str, default = 'gelu'
         Type of activation used in MLP block.
-        Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
+        Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
+        'silu', and 'swiglu'.
     device : Union[torch.device, str], default = "cuda"
         The device on which the parameters of the model will be allocated. It is the user's
         responsibility to ensure all parameters are moved to the GPU before running the
```
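A short usage sketch of the documented option (layer sizes are arbitrary; the layer allocates on CUDA by default):

```python
import transformer_engine.pytorch as te

layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    activation="swiglu",  # any of the options listed above
)
```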
transformer_engine/pytorch/triton/cross_entropy.py (+3, -0)

```diff
@@ -231,6 +231,7 @@ def element_mul_kernel(
     X_ptr,
     X_stride,
     grad_output_ptr,
+    grad_output_stride,
     n_cols,
     BLOCK_SIZE: tl.constexpr,
 ):
@@ -253,6 +254,7 @@ def element_mul_kernel(
     X_ptr += program_id * X_stride

     # Load the gradient output value
+    grad_output_ptr += program_id * grad_output_stride
     grad_output = tl.load(grad_output_ptr)

     # Perform the element-wise multiplication
@@ -361,6 +363,7 @@ def cross_entropy_backward(
         _input,
         _input.stride(-2),
         grad_output,
+        1 if grad_output.numel() > 1 else 0,
         V,
         BLOCK_SIZE=BLOCK_SIZE,
         num_warps=16 if IS_HIP_EXTENSION else 32,
```
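The new stride argument lets one kernel serve both a scalar (reduced) and a per-row (unreduced) `grad_output`: with stride 0 every program instance loads the same element, with stride 1 each row advances to its own. A plain-Python model of the indexing:

```python
import torch

def grad_for_program(grad_output: torch.Tensor, program_id: int) -> torch.Tensor:
    # Mirrors: grad_output_ptr += program_id * grad_output_stride
    stride = 1 if grad_output.numel() > 1 else 0
    return grad_output.flatten()[program_id * stride]

g_scalar = torch.tensor(2.0)   # reduced loss: one grad broadcast to all rows
g_rows = torch.arange(4.0)     # unreduced loss: one grad per row
assert grad_for_program(g_scalar, 3) == 2.0
assert grad_for_program(g_rows, 3) == 3.0
```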