Commit 27ddce40 authored by wenjh

Merge branch 'nv_main'

parents d262ef4c 5b3092a0
......@@ -19,6 +19,7 @@ from transformer_engine.pytorch.ops.op import (
)
from transformer_engine.pytorch.ops.fused import (
fuse_backward_activation_bias,
fuse_backward_add_rmsnorm,
fuse_backward_linear_add,
fuse_backward_linear_scale,
fuse_forward_linear_bias_activation,
......@@ -371,6 +372,7 @@ class OperationFuser:
ops = fuse_backward_linear_add(ops)
ops = fuse_backward_linear_scale(ops)
ops = fuse_backward_activation_bias(ops, recipe)
ops = fuse_backward_add_rmsnorm(ops)
return ops
def maybe_fuse_ops(
......
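For orientation, backward fusion passes in this fuser follow a simple shape: walk the op list, look for an adjacent pair that the backward pass can merge, and splice in a fused op. The sketch below is illustrative only; the class names and pairing rule are placeholders, not the actual `fuse_backward_add_rmsnorm` implementation.

```python
# Illustrative sketch only: placeholder classes standing in for TE basic ops.
class Add:
    pass


class RMSNorm:
    pass


class FusedAddRMSNormBackward:
    """Placeholder fused op wrapping an RMSNorm and an add."""

    def __init__(self, norm_op, add_op):
        self.norm_op, self.add_op = norm_op, add_op


def fuse_backward_add_rmsnorm_sketch(ops):
    """Return a new op list with eligible adjacent (RMSNorm, Add) pairs fused."""
    out, i = [], 0
    while i < len(ops):
        if (
            i + 1 < len(ops)
            and isinstance(ops[i], RMSNorm)
            and isinstance(ops[i + 1], Add)
        ):
            out.append(FusedAddRMSNormBackward(ops[i], ops[i + 1]))
            i += 2
        else:
            out.append(ops[i])
            i += 1
    return out


print(len(fuse_backward_add_rmsnorm_sketch([RMSNorm(), Add(), Add()])))  # 2
```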
......@@ -54,7 +54,8 @@ class Linear(FusedOperation):
weight's `main_grad` attribute instead of relying on PyTorch
autograd. The weight's `main_grad` must be set externally and
there is no guarantee that `grad` will be set or be
meaningful.
meaningful. This is primarily intended to integrate with
Megatron-LM.
"""
......
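As context for the main_grad contract described above, here is a rough usage sketch of the Megatron-style integration: the caller attaches an FP32 `main_grad` buffer to the weight before backward, and gradients accumulate there rather than in `.grad`. The setup and the assumption that the accumulation option is enabled on the op are for illustration only, not an exact recipe.

```python
import torch
import transformer_engine.pytorch as te

# Sketch only: Megatron-LM normally allocates main_grad buffers itself; here we
# attach one by hand to illustrate the contract described in the docstring above.
layer = te.Linear(1024, 1024, bias=False)  # assumes a CUDA device is available
layer.weight.main_grad = torch.zeros_like(layer.weight, dtype=torch.float32)

x = torch.randn(8, 1024, device=layer.weight.device, requires_grad=True)
y = layer(x)
y.sum().backward()

# With main_grad accumulation enabled on the op, gradients land in
# layer.weight.main_grad; layer.weight.grad may be unset or meaningless.
```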
......@@ -10,14 +10,30 @@ import sys
import os
import shutil
from pathlib import Path
import platform
import urllib
import setuptools
from wheel.bdist_wheel import bdist_wheel as _bdist_wheel
from packaging.version import parse
try:
import torch
from torch.utils.cpp_extension import BuildExtension
except ImportError as e:
raise RuntimeError("This package needs Torch to build.") from e
FORCE_BUILD = os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE"
FORCE_CXX11_ABI = os.getenv("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE"
SKIP_CUDA_BUILD = os.getenv("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE"
PACKAGE_NAME = "transformer_engine_torch"
BASE_WHEEL_URL = (
"https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}"
)
# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as
# torch._C._GLIBCXX_USE_CXX11_ABI
# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920
if FORCE_CXX11_ABI:
torch._C._GLIBCXX_USE_CXX11_ABI = True
current_file_path = Path(__file__).parent.resolve()
build_tools_dir = current_file_path.parent.parent / "build_tools"
......@@ -31,13 +47,94 @@ if bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))) or os.path.isdir(build_tools_
from build_tools.build_ext import get_build_ext
from build_tools.utils import copy_common_headers
from build_tools.te_version import te_version
from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements
from build_tools.pytorch import (
setup_pytorch_extension,
install_requirements,
test_requirements,
)
os.environ["NVTE_PROJECT_BUILDING"] = "1"
CMakeBuildExtension = get_build_ext(BuildExtension, True)
def get_platform():
"""
Returns the platform name as used in wheel filenames.
"""
if sys.platform.startswith("linux"):
return f"linux_{platform.uname().machine}"
if sys.platform == "darwin":
mac_version = ".".join(platform.mac_ver()[0].split(".")[:2])
return f"macosx_{mac_version}_x86_64"
if sys.platform == "win32":
return "win_amd64"
raise ValueError(f"Unsupported platform: {sys.platform}")
def get_wheel_url():
"""Construct the wheel URL for the current platform."""
torch_version_raw = parse(torch.__version__)
python_version = f"cp{sys.version_info.major}{sys.version_info.minor}"
platform_name = get_platform()
nvte_version = te_version()
torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}"
cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper()
# Determine the version numbers used to select the correct wheel
# We're using the CUDA version used to build torch, not the one currently installed
# _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME)
torch_cuda_version = parse(torch.version.cuda)
# For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3
# to save CI time. Minor versions should be compatible.
torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3")
# cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}"
cuda_version = f"{torch_cuda_version.major}"
# Determine wheel URL based on CUDA version, torch version, python version and OS
wheel_filename = f"{PACKAGE_NAME}-{nvte_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{nvte_version}", wheel_name=wheel_filename)
return wheel_url, wheel_filename
class CachedWheelsCommand(_bdist_wheel):
"""
The CachedWheelsCommand plugs into the default bdist wheel, which is run by pip when it cannot
find an existing wheel (which is currently the case for all grouped gemm installs). We use
the environment parameters to detect whether a compatible pre-built wheel is already available
and short-circuit the standard full build pipeline.
"""
def run(self):
if FORCE_BUILD:
return super().run()
wheel_url, wheel_filename = get_wheel_url()
print("Guessing wheel URL: ", wheel_url)
try:
urllib.request.urlretrieve(wheel_url, wheel_filename)
# Make the archive
# Lifted from the root wheel processing command
# https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85
if not os.path.exists(self.dist_dir):
os.makedirs(self.dist_dir)
impl_tag, abi_tag, plat_tag = self.get_tag()
archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}"
wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl")
print("Raw wheel path", wheel_path)
os.rename(wheel_filename, wheel_path)
except (urllib.error.HTTPError, urllib.error.URLError):
print("Precompiled wheel not found. Building from source...")
# If the wheel could not be downloaded, build from source
super().run()
if __name__ == "__main__":
# Extensions
common_headers_dir = "common_headers"
......@@ -50,11 +147,11 @@ if __name__ == "__main__":
# Configure package
setuptools.setup(
name="transformer_engine_torch",
name=PACKAGE_NAME,
version=te_version(),
description="Transformer acceleration library - Torch Lib",
ext_modules=ext_modules,
cmdclass={"build_ext": CMakeBuildExtension},
cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand},
install_requires=install_requirements(),
tests_require=test_requirements(),
)
......
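To make the naming scheme used by `get_wheel_url` concrete, the snippet below evaluates the same format string for an assumed environment (Python 3.10, torch 2.3 built against CUDA 12.x, CXX11 ABI on, x86-64 Linux, TE version 2.3.0; all of these values are hypothetical):

```python
# Hypothetical inputs; in setup.py these come from torch, te_version(), and the platform.
PACKAGE_NAME = "transformer_engine_torch"
nvte_version, torch_version, cuda_version = "2.3.0", "2.3", "12"
python_version, cxx11_abi, platform_name = "cp310", "TRUE", "linux_x86_64"

wheel_filename = (
    f"{PACKAGE_NAME}-{nvte_version}+cu{cuda_version}torch{torch_version}"
    f"cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl"
)
print(wheel_filename)
# transformer_engine_torch-2.3.0+cu12torch2.3cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
```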
......@@ -350,9 +350,14 @@ class Float8BlockwiseQTensorBase(QuantizedTensorBase):
def _transpose_columnwise_data(self):
"""Plainly transpose the columnwise data and scale inv."""
if self._columnwise_data is not None:
# TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically
# deallocated by GC. Manually deallocating is a temporary hack.
_old_data = self._columnwise_data
self._columnwise_data = tex.fp8_transpose(
self._columnwise_data, self._fp8_dtype, out=None
)
_old_data.data = _empty_tensor()
del _old_data
def __repr__(self):
if self._rowwise_data is not None:
......
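The manual release above leans on a standard PyTorch trick: re-pointing a tensor's `.data` at an empty tensor drops the original storage even while other Python references to the tensor object remain. A standalone sketch (requires a CUDA device; sizes are arbitrary):

```python
import torch

old = torch.empty(4096, 4096, device="cuda")   # 64 MiB of FP32
alias = old                                    # another reference to the same object
before = torch.cuda.memory_allocated()

old.data = torch.Tensor()                      # analogous to `_old_data.data = _empty_tensor()`
del old
after = torch.cuda.memory_allocated()

print((before - after) // 2**20, "MiB freed")  # the buffer is gone despite `alias`
```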
......@@ -95,8 +95,13 @@ class Float8TensorBase(QuantizedTensorBase):
return instance
def clear(self):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully."""
for t in (self._data, self._transpose, self._scale_inv):
"""Deallocate this tensor's memory. Typically not needed and must be used carefully.
Scale-inv tensor is not deallocated because it's often shared
between multiple FP8 tensors.
"""
for t in (self._data, self._transpose):
if t is not None:
t.data = _empty_tensor()
self._transpose_invalid = True
......
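The reason `clear()` now spares the scale-inverse is that several FP8 tensors can reference the same scale-inv buffer. A conceptual sketch with plain tensors (not the real `Float8TensorBase` fields):

```python
import torch

# Conceptual sketch: two quantized tensors holding references to one scale-inv buffer.
scale_inv = torch.tensor([0.125])
a = {"data": torch.randint(0, 255, (8,), dtype=torch.uint8), "scale_inv": scale_inv}
b = {"data": torch.randint(0, 255, (8,), dtype=torch.uint8), "scale_inv": scale_inv}

# clear() analog for `a`: release its own payload only.
a["data"].data = torch.Tensor()

# `b` can still be dequantized because the shared scale-inv was left untouched.
dequant_b = b["data"].float() * b["scale_inv"]
print(dequant_b.shape)
```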
......@@ -178,7 +178,7 @@ class Float8Quantizer(Quantizer):
def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item())
out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
out = out.to(tensor.dtype)
return out
......@@ -351,15 +351,25 @@ class Float8CurrentScalingQuantizer(Quantizer):
def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX quantization yet."
)
if tensor.dtype != torch.float32:
tensor = tensor.to(torch.float32)
data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor)
return Float8Tensor(
shape=data.shape,
dtype=torch.float32,
data=data,
fp8_scale_inv=scale_inv,
fp8_dtype=self.dtype,
requires_grad=False,
data_transpose=None,
quantizer=self,
)
def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor:
"""Function using primitives with ONNX defined translations."""
raise NotImplementedError(
"Float8CurrentScalingQuantizer does not support ONNX dequantization yet."
)
out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv)
out = out.to(tensor.dtype)
return out
def _canonicalized_amax_reduction_group(self) -> dist_group_type:
"""Get process group for amax reduction"""
......
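Taken together, the new onnx_quantize/onnx_dequantize pair above amounts to the round trip sketched below. The custom-op names come from the diff itself; the input values are arbitrary, and the snippet assumes transformer_engine is imported so that the `tex` custom ops are registered (typically as part of the ONNX export path).

```python
import torch
import transformer_engine.pytorch  # noqa: F401  -- assumed to register the tex custom ops

x = torch.randn(128, 128, device="cuda", dtype=torch.float32)

# Current-scaling quantize: returns FP8 data plus the scale-inverse it computed.
data, scale_inv = torch.ops.tex.fp8_cs_quantize(x)

# Dequantize with the same scale-inverse and compare against the input.
x_roundtrip = torch.ops.tex.fp8_dequantize(data, scale_inv).to(torch.float32)
print(torch.max(torch.abs(x - x_roundtrip)))  # bounded by FP8 rounding error
```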
......@@ -175,7 +175,8 @@ class TransformerLayer(torch.nn.Module):
if set to `False`, the transformer layer will not learn any additive biases.
activation : str, default = 'gelu'
Type of activation used in MLP block.
Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'.
Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu',
'silu', and 'swiglu'.
device : Union[torch.device, str], default = "cuda"
The device on which the parameters of the model will be allocated. It is the user's
responsibility to ensure all parameters are moved to the GPU before running the
......
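A minimal construction example for the expanded activation list (argument values are arbitrary; any of the documented strings can be passed):

```python
import transformer_engine.pytorch as te

# Usage sketch: pick one of the documented activation strings for the MLP block.
layer = te.TransformerLayer(
    hidden_size=1024,
    ffn_hidden_size=4096,
    num_attention_heads=16,
    activation="qgeglu",  # e.g. one of the newly documented options
)
```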
......@@ -231,6 +231,7 @@ def element_mul_kernel(
X_ptr,
X_stride,
grad_output_ptr,
grad_output_stride,
n_cols,
BLOCK_SIZE: tl.constexpr,
):
......@@ -253,6 +254,7 @@ def element_mul_kernel(
X_ptr += program_id * X_stride
# Load the gradient output value
grad_output_ptr += program_id * grad_output_stride
grad_output = tl.load(grad_output_ptr)
# Perform the element-wise multiplication
......@@ -361,6 +363,7 @@ def cross_entropy_backward(
_input,
_input.stride(-2),
grad_output,
1 if grad_output.numel() > 1 else 0,
V,
BLOCK_SIZE=BLOCK_SIZE,
num_warps=16 if IS_HIP_EXTENSION else 32,
......
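The new `grad_output_stride` argument lets the same kernel serve both cases at the call site above: a scalar `grad_output` (stride 0, so every row reads the single value) and a per-row `grad_output` (stride 1, so row i reads element i). A CPU sketch of that indexing rule, with made-up shapes:

```python
import torch

def scale_rows(x, grad_output):
    # Mirrors the kernel's indexing: row i reads grad_output[i * stride].
    stride = 1 if grad_output.numel() > 1 else 0
    flat = grad_output.reshape(-1)
    return torch.stack([x[i] * flat[i * stride] for i in range(x.shape[0])])

x = torch.randn(4, 8)
print(scale_rows(x, torch.tensor(2.0)))       # scalar grad: stride 0 broadcasts it
print(scale_rows(x, torch.arange(1.0, 5.0)))  # per-row grad: stride 1 indexes each row
```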