Unverified commit 0587ecf4, authored by Jan Bielak and committed by GitHub

Optimize reshaping tensors in the `te.ops.Sequential` implementation (#1876)



* Optimize _common.reshape by removing redundant operations
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Use view instead of reshape when possible
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Simplify convert_tensor (requires testing)
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Remove reshape
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Refactor existing code to use maybe_dequantize where possible
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Check if tensor is any kind of quantized tensor
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

* Revert "Check if tensor is any kind of quantized tensor"

This reverts commit cf09d61ffe41f38720d820ddc4f011f9dc1fb56e.
Signed-off-by: Jan Bielak <jbielak@nvidia.com>

---------
Signed-off-by: Jan Bielak <jbielak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
parent 1d1d3233
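Editor's note: the core idea of the diff below is that every tensor handed to a kernel is made contiguous exactly once, after which `view` can reinterpret the storage with no copy, whereas `reshape` would re-check strides and may silently allocate. A minimal standalone sketch of that invariant (tensor names are illustrative, not from the diff):

```python
import torch

# A transpose produces non-contiguous strides, the case reshape() had to handle.
x = torch.randn(4, 8, 16).transpose(0, 1)

# Pay for at most one copy up front...
x = x.contiguous()

# ...then view() is guaranteed to be a zero-copy reinterpretation.
y = x.view(-1, x.size(-1))
assert y.data_ptr() == x.data_ptr()  # same storage, no new allocation
```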
@@ -5,7 +5,7 @@
"""Helper functions used in fusible operations."""
from __future__ import annotations
from typing import Any, Iterable, Optional
from typing import Any, Optional
import torch
@@ -13,11 +13,8 @@ from transformer_engine_torch import FP8TensorMeta
from .. import torch_version
from ..fp8 import FP8GlobalStateManager
from ..tensor.float8_tensor import Float8Tensor
from ..utils import (
canonicalize_device,
canonicalize_dtype,
devices_match,
)
from ..tensor.quantized_tensor import QuantizedTensorBase
from ..utils import canonicalize_dtype
def is_float8_tensor(tensor: Any) -> bool:
@@ -25,74 +22,17 @@ def is_float8_tensor(tensor: Any) -> bool:
return isinstance(tensor, Float8Tensor)
def convert_tensor(
tensor: torch.Tensor | Float8Tensor,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
memory_format: torch.memory_format = torch.preserve_format,
) -> torch.Tensor | Float8Tensor:
"""Convert tensor attributes, keeping same data if possible"""
# Default kwargs
if device is None:
device = tensor.device
device = canonicalize_device(device)
if dtype is None:
dtype = tensor.dtype
dtype = canonicalize_dtype(dtype)
# Make sure output is detached from autograd graph
tensor = tensor.detach()
# Return immediately if tensor already has desired attributes
if devices_match(device, tensor.device) and dtype == tensor.dtype:
if memory_format == torch.preserve_format or tensor.is_contiguous(
memory_format=memory_format
):
return tensor
# Convert FP8 tensor
if is_float8_tensor(tensor):
data = tensor._data
if not devices_match(device, data.device):
data = data.to(device=device)
if memory_format != torch.preserve_format and not data.is_contiguous(
memory_format=memory_format
):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
data = data.contiguous(memory_format=memory_format)
out = Float8Tensor.make_like(tensor, dtype=dtype)
out.data = data
return out
# Convert standard PyTorch tensor
tensor = tensor.to(device=device, dtype=dtype)
if memory_format != torch.preserve_format and not tensor.is_contiguous(
memory_format=memory_format
):
# Note: torch.Tensor.to ignores memory_format kwarg (see
# https://github.com/pytorch/pytorch/issues/132020).
tensor = tensor.contiguous(memory_format=memory_format)
def maybe_dequantize(
tensor: torch.Tensor | QuantizedTensorBase, dtype: torch.dtype | None = None
) -> torch.Tensor:
"""Dequantize tensor to given dtype or just convert if not a quantized tensor"""
if isinstance(tensor, QuantizedTensorBase):
return tensor.dequantize(dtype=dtype)
if dtype is not None and tensor.dtype != dtype:
return tensor.to(dtype)
return tensor
def reshape(
tensor: torch.Tensor | Float8Tensor,
shape: Iterable[int],
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
) -> torch.Tensor | Float8Tensor:
"""Reshape tensor, keeping same data if possible"""
tensor = convert_tensor(
tensor,
device=device,
dtype=dtype,
memory_format=torch.contiguous_format,
)
return tensor.reshape(*shape)
def maybe_autocast_dtype(
*,
device_type: str = "cuda",
......
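Editor's note: with `maybe_dequantize` in place, the call sites collapse the old dequantize/cast/contiguous branches into a single line. A hedged sketch of the resulting pattern (the import path is inferred from the hunks; `prepare_kernel_input` is a hypothetical wrapper for illustration, not a TE function):

```python
import torch
from transformer_engine.pytorch.ops._common import maybe_dequantize

def prepare_kernel_input(input_: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    # One call covers three old branches: dequantize a quantized tensor,
    # cast a mismatched dtype, or pass a conforming tensor through untouched.
    x = maybe_dequantize(input_.contiguous(), dtype)
    # Safe to use view() here: contiguous() already ran above.
    return x.view(-1, x.size(-1))
```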
@@ -12,11 +12,10 @@ import torch
import transformer_engine_torch as tex
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...tensor.float8_tensor import Float8CurrentScalingQuantizer
from ...utils import clear_tensor_data, devices_match
from ...utils import clear_tensor_data
from ..op import BasicOperation, OperationContext
from .._common import reshape
from .._common import maybe_dequantize
class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
@@ -86,15 +85,7 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
raise RuntimeError(f"Unsupported dtype ({dtype})")
# Check input tensor
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if x.device.type != "cuda":
x = x.cuda()
if x.dtype != dtype:
x = x.to(dtype=dtype)
if not x.is_contiguous():
x = x.contiguous()
x = maybe_dequantize(input_.contiguous(), dtype)
# Check if quantized compute is enabled
quantized_compute_enabled = FP8GlobalStateManager.is_fp8_enabled()
@@ -108,13 +99,13 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Launch kernel
y = self._activation_forward_impl(
reshape(x, (-1, x.size(-1))),
x.view((-1, x.size(-1))),
quantizer,
)
# Check output tensor
if y.dim() != x.dim():
y = y.reshape(list(x.shape[:-1]) + [-1])
y = y.view(list(x.shape[:-1]) + [-1])
# Quantize input to FP8 before caching if needed
if self.cache_quantized_input:
@@ -140,21 +131,10 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
(x,) = ctx.saved_tensors
# Check input tensor
if isinstance(x, QuantizedTensor):
x = x.dequantize(dtype=ctx.dtype)
elif x.dtype != ctx.dtype:
x = x.to(dtype=ctx.dtype)
if not x.is_contiguous():
x = x.contiguous()
x = maybe_dequantize(x.contiguous(), ctx.dtype)
# Check grad output tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize(dtype=ctx.dtype)
if not devices_match(dy.device, x.device) or dy.dtype != x.dtype:
dy = dy.to(device=x.device, dtype=x.dtype)
if not dy.is_contiguous():
dy = dy.contiguous()
dy = maybe_dequantize(grad_output.contiguous(), x.dtype)
# Check if quantized compute is enabled
quantizer = None
@@ -167,14 +147,14 @@ class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta):
# Launch kernel
dx = self._activation_backward_impl(
reshape(dy, (-1, dy.size(-1))),
reshape(x, (-1, x.size(-1))),
dy.view((-1, dy.size(-1))),
x.view((-1, x.size(-1))),
quantizer,
)
# Check grad input tensor
if dx.size() != x.size():
dx = dx.reshape(x.size())
dx = dx.view(x.size())
# Clear input tensor if possible
if ctx.prev_op is not None:
......
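Editor's note: the activation hunks above all follow the same flatten-to-2D/restore pattern. A self-contained sketch with `torch.nn.functional.gelu` standing in for the fused activation kernel (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 8)                  # arbitrary leading dimensions

# Kernels operate on a 2D view: flatten everything except the last dim.
y2d = F.gelu(x.view(-1, x.size(-1)))

# Restore the leading dims; -1 lets view infer the (possibly changed) last dim.
y = y2d.view(list(x.shape[:-1]) + [-1])
assert y.shape == x.shape
```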
@@ -10,7 +10,7 @@ from typing import Optional
import torch
from ...distributed import gather_along_first_dim
from ...tensor import QuantizedTensor
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
@@ -71,10 +71,7 @@ class AllGather(BasicOperation):
input_dims[0] //= self.process_group_size
# Check output gradient tensor
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = dy.contiguous()
dy = maybe_dequantize(grad_output.contiguous())
# Perform reduce-scatter
dx = torch.empty(input_dims, dtype=dy.dtype, device=dy.device)
......
@@ -9,7 +9,7 @@ from typing import Optional
import torch
from ...tensor import QuantizedTensor
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
@@ -50,10 +50,7 @@ class AllReduce(BasicOperation):
return input_
# Perform all-reduce
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
x = maybe_dequantize(input_.contiguous())
torch.distributed.all_reduce(x, group=self.process_group)
return x
......
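Editor's note: the communication ops follow the same recipe, dequantize or cast once, then hand a plain contiguous tensor to `torch.distributed`. A single-process sketch that actually runs (gloo backend at world size 1, chosen only to make the call executable; `maybe_dequantize` would sit where the comment indicates):

```python
import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

x = torch.randn(4, 8).contiguous()  # maybe_dequantize(input_.contiguous()) in the real op
dist.all_reduce(x)                  # in-place sum across ranks (a no-op at world size 1)

dist.destroy_process_group()
```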
@@ -27,12 +27,13 @@ from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ..op import BasicOperation, OperationContext
from .._common import (
from .._common import maybe_dequantize
from ...utils import (
canonicalize_device,
canonicalize_dtype,
clear_tensor_data,
devices_match,
)
from ...utils import clear_tensor_data
def _wait_async(handle: Optional[Any]) -> None:
@@ -466,10 +467,8 @@ class BasicLinear(BasicOperation):
x_local = input_quantizer(x_local)
x = x_local
else:
if isinstance(x_local, QuantizedTensor):
x_local = x_local.dequantize()
if x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
if with_x_all_gather:
x, x_async = gather_along_first_dim(
x_local,
@@ -481,16 +480,13 @@ class BasicLinear(BasicOperation):
# Check weight tensor
w = weight
w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute and not w_is_quantized:
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not isinstance(w, QuantizedTensor):
if weight_quantizer is None:
raise ValueError("Missing quantizer for weight tensor")
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
w = w.to(dtype=dtype)
# Check output tensor
y = out
@@ -700,10 +696,8 @@ class BasicLinear(BasicOperation):
dy_local = grad_output_quantizer(dy_local)
dy = dy_local
else:
if isinstance(dy_local, QuantizedTensor):
dy_local = dy_local.dequantize()
if dy_local.dtype != dtype:
dy_local = dy_local.to(dtype=dtype)
dy_local = maybe_dequantize(dy_local, dtype)
if with_dy_all_gather:
dy, dy_async = gather_along_first_dim(
dy_local,
@@ -739,10 +733,8 @@ class BasicLinear(BasicOperation):
x_local = input_quantizer(x_local)
x = x_local
else:
if isinstance(x_local, QuantizedTensor):
x_local = x_local.dequantize()
if x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
if with_x_all_gather:
x, x_async = gather_along_first_dim(
x_local,
@@ -761,9 +753,8 @@ class BasicLinear(BasicOperation):
if weight is None:
raise ValueError("Weight tensor is required to compute input grad")
w = weight
w_is_quantized = isinstance(w, QuantizedTensor)
if with_quantized_compute:
if w_is_quantized:
if isinstance(w, QuantizedTensor):
w.update_usage(columnwise_usage=True)
else:
if weight_quantizer is None:
@@ -771,10 +762,7 @@ class BasicLinear(BasicOperation):
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
if w_is_quantized:
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)
w = maybe_dequantize(w, dtype)
# Synchronize tensor-parallel communication
_wait_async(dy_async)
......
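Editor's note: in `BasicLinear`, the weight-preparation branches are reordered so the non-quantized path is handled first and `maybe_dequantize` absorbs the dequantize-or-cast logic. A condensed sketch of the new control flow (import paths inferred from the hunks above; `prepare_weight` is a hypothetical standalone rendering of the inline branch, not a real TE function):

```python
from typing import Callable, Optional

import torch
from transformer_engine.pytorch.ops._common import maybe_dequantize
from transformer_engine.pytorch.tensor import QuantizedTensor

def prepare_weight(
    w: torch.Tensor,
    with_quantized_compute: bool,
    dtype: torch.dtype,
    weight_quantizer: Optional[Callable],
) -> torch.Tensor:
    if not with_quantized_compute:
        # Plain compute: dequantize a quantized weight or cast dtype as needed.
        return maybe_dequantize(w, dtype)
    if not isinstance(w, QuantizedTensor):
        # Quantized compute with a plain weight: quantize on the fly.
        if weight_quantizer is None:
            raise ValueError("Missing quantizer for weight tensor")
        return weight_quantizer(w)
    return w  # already quantized, use as-is
```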
@@ -13,7 +13,7 @@ from transformer_engine.pytorch.ops.op import (
BasicOperation,
OperationContext,
)
from .._common import (
from ...utils import (
canonicalize_device,
canonicalize_dtype,
)
@@ -124,7 +124,7 @@ class Bias(BasicOperation):
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
x = input_
b = self.bias.reshape([1] * (x.dim() - 1) + [self.local_size])
b = self.bias.view([1] * (x.dim() - 1) + [self.local_size])
return x + b
def op_backward(
......
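Editor's note: the `Bias` change is the same zero-copy idea applied to broadcasting: a 1D bias viewed as `[1, ..., 1, size]` so the add broadcasts over all leading dimensions. Standalone sketch (sizes are illustrative):

```python
import torch

x = torch.randn(4, 6, 16)
bias = torch.randn(16)

# view() instead of reshape(): the freshly created 1D bias is contiguous,
# so the reinterpretation is guaranteed to be copy-free.
b = bias.view([1] * (x.dim() - 1) + [16])  # shape (1, 1, 16)
y = x + b                                  # broadcasts over leading dims
assert y.shape == x.shape
```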
@@ -9,8 +9,8 @@ from typing import Optional
import torch
from ...tensor import QuantizedTensor
from ...utils import clear_tensor_data
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
from ...jit import (
l2normalization_fused,
@@ -77,10 +77,7 @@ class L2Normalization(BasicOperation):
next_op: Optional[BasicOperation] = None,
) -> torch.Tensor:
# Use input directly - torch.compile can handle multi-dimensional tensors
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = maybe_dequantize(input_)
# Check if backward pass is needed
requires_grad = ctx.requires_grad
@@ -111,10 +108,7 @@ class L2Normalization(BasicOperation):
# Saved tensors from forward pass
x, rsqrt_norm = ctx.saved_tensors
dy = grad_output
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = maybe_dequantize(grad_output)
# Compute L2 norm backward pass using fused implementation
dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps)
......
@@ -15,7 +15,6 @@ import torch
from transformer_engine_torch import layernorm_bwd, layernorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...constants import TE_DType
from ...tensor import QuantizedTensor
from ...utils import (
canonicalize_device,
canonicalize_dtype,
@@ -23,7 +22,7 @@ from ...utils import (
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape
from .._common import maybe_autocast_dtype, maybe_dequantize
class LayerNorm(BasicOperation):
@@ -192,19 +191,10 @@ class LayerNorm(BasicOperation):
# Check input tensors
inner_dim = math.prod(weight_dims)
device = weight.device
if device.type != "cuda":
device = canonicalize_device(None)
dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
b = reshape(self.bias, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(b, QuantizedTensor):
b = b.dequantize()
x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
b = maybe_dequantize(self.bias, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
@@ -235,12 +225,11 @@ class LayerNorm(BasicOperation):
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, means, rstdevs)
ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
# Reshape output tensor
out = reshape(y, input_dims)
out = y.view(input_dims)
return out
def op_backward(
@@ -257,14 +246,9 @@ class LayerNorm(BasicOperation):
inner_dim = math.prod(weight_dims)
# Check input tensors
device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Compute layer norm backward pass
dx, dw, db = layernorm_bwd(
@@ -284,7 +268,7 @@ class LayerNorm(BasicOperation):
clear_tensor_data(rstdevs)
# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, weight_dims)
grad_bias = reshape(db, weight_dims)
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
grad_bias = db.view(weight_dims)
return grad_input, (grad_weight, grad_bias)
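Editor's note: the `LayerNorm` hunks above replace four `reshape` calls with `maybe_dequantize(...).view(...)`, and since the op no longer moves tensors across devices, `ctx.device` is dropped as well. A sketch of the equivalent preparation, with `F.layer_norm` standing in for the fused `layernorm_fwd` kernel (shapes and eps are illustrative):

```python
import math
import torch
import torch.nn.functional as F

inp = torch.randn(4, 5, 32)
weight, bias = torch.randn(32), torch.randn(32)

inner_dim = math.prod(weight.shape)
x = inp.contiguous().view(-1, inner_dim)   # 2D view for the kernel
y = F.layer_norm(x, (inner_dim,), weight.view(inner_dim), bias.view(inner_dim), eps=1e-5)
out = y.view(inp.size())                   # restore the caller's shape
assert out.shape == inp.shape
```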
@@ -10,7 +10,7 @@ from typing import Optional
import torch
from ...distributed import gather_along_first_dim
from ...tensor import QuantizedTensor
from .._common import maybe_dequantize
from ..op import BasicOperation, OperationContext
@@ -59,10 +59,7 @@ class ReduceScatter(BasicOperation):
output_dims[0] //= self.process_group_size
# Check input tensor
x = input_
if isinstance(x, QuantizedTensor):
x = x.dequantize()
x = x.contiguous()
x = maybe_dequantize(input_.contiguous())
# Perform reduce-scatter
y = torch.empty(output_dims, dtype=x.dtype, device=x.device)
......
@@ -14,7 +14,6 @@ import torch
from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd
from ...fp8 import FP8GlobalStateManager
from ...tensor import QuantizedTensor
from ...constants import TE_DType
from ...utils import (
canonicalize_device,
@@ -23,7 +22,7 @@ from ...utils import (
devices_match,
)
from ..op import BasicOperation, OperationContext
from .._common import maybe_autocast_dtype, reshape
from .._common import maybe_autocast_dtype, maybe_dequantize
class RMSNorm(BasicOperation):
@@ -175,16 +174,9 @@ class RMSNorm(BasicOperation):
# Check input tensors
inner_dim = math.prod(weight_dims)
device = weight.device
if device.type != "cuda":
device = canonicalize_device(None)
dtype = maybe_autocast_dtype(default_dtype=weight.dtype)
x = reshape(input_, (-1, inner_dim), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(x, QuantizedTensor):
x = x.dequantize()
if isinstance(w, QuantizedTensor):
w = w.dequantize()
x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim))
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Check if backward pass is needed
requires_grad = ctx.requires_grad
@@ -214,12 +206,11 @@ class RMSNorm(BasicOperation):
# Save state for backward pass
if requires_grad:
ctx.save_for_backward(x, rstdevs)
ctx.device = device
ctx.dtype = dtype
ctx.has_prev_op = prev_op is not None
# Reshape output tensor
out = reshape(y, input_dims)
out = y.view(input_dims)
return out
def op_backward(
@@ -236,14 +227,9 @@ class RMSNorm(BasicOperation):
inner_dim = math.prod(weight_dims)
# Check input tensors
device = ctx.device
dtype = ctx.dtype
dy = reshape(grad_output, x.size(), device=device, dtype=dtype)
w = reshape(self.weight, (inner_dim,), device=device, dtype=dtype)
if isinstance(w, QuantizedTensor):
w = w.dequantize()
if isinstance(dy, QuantizedTensor):
dy = dy.dequantize()
dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size())
w = maybe_dequantize(self.weight, dtype).view((inner_dim,))
# Compute RMSNorm backward pass
dx, dw = rmsnorm_bwd(
@@ -261,6 +247,6 @@ class RMSNorm(BasicOperation):
clear_tensor_data(rstdevs)
# Reshape results
grad_input = reshape(dx, grad_output.size())
grad_weight = reshape(dw, weight_dims)
grad_input = dx.view(grad_output.size())
grad_weight = dw.view(weight_dims)
return grad_input, (grad_weight,)
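Editor's note: `RMSNorm` gets the identical treatment. For reference, a plain-PyTorch sketch of what `rmsnorm_fwd` computes (assuming eps is added to the mean square before the reciprocal square root, which matches the `rstdevs` saved for backward; values are illustrative):

```python
import torch

x = torch.randn(6, 32)
w = torch.randn(32)
eps = 1e-5

# Reciprocal of the root-mean-square, the quantity cached for backward.
rstdev = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
y = x * rstdev * w  # normalize, then apply the learnable scale
```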
@@ -24,6 +24,7 @@ from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.mxfp8_tensor import MXFP8Quantizer
from ...utils import canonicalize_device, canonicalize_dtype, clear_tensor_data
from ..basic import BasicLinear, Bias, ReduceScatter
from .._common import maybe_dequantize
from ..op import FusedOperation, FusibleOperation, OperationContext
@@ -293,10 +294,7 @@ class UserbuffersBackwardLinear(FusedOperation):
)
dy_local = grad_output_quantizer(dy_local)
else:
if isinstance(dy_local, QuantizedTensorBase):
dy_local = dy_local.dequantize(dtype=dtype)
elif dy_local.dtype != dtype:
dy_local = dy_local.to(dtype=dtype)
dy_local = maybe_dequantize(dy_local, dtype)
# Cast weight tensor dtype if needed
if weight is None:
@@ -307,10 +305,7 @@ class UserbuffersBackwardLinear(FusedOperation):
weight_quantizer.set_usage(columnwise=True)
w = weight_quantizer(w)
else:
if isinstance(w, QuantizedTensorBase):
w = w.dequantize(dtype=dtype)
elif w.dtype != dtype:
w = w.to(dtype=dtype)
w = maybe_dequantize(w, dtype)
# Cast input tensor dtype if needed
x_local = None
@@ -323,10 +318,7 @@ class UserbuffersBackwardLinear(FusedOperation):
input_quantizer.set_usage(columnwise=True)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
x_local = x_local.dequantize(dtype=dtype)
elif x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
# dgrad GEMM
dx_local = None
......
@@ -24,6 +24,7 @@ from ...tensor.quantized_tensor import QuantizedTensorBase, Quantizer
from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer
from ...tensor._internal.float8_tensor_base import Float8TensorBase
from ...utils import canonicalize_device, canonicalize_dtype
from .._common import maybe_dequantize
from ..basic import BasicLinear, Bias, ReduceScatter
from ..op import (
BasicOperation,
@@ -226,22 +227,16 @@ class UserbuffersForwardLinear(FusedOperation):
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad)
x_local = input_quantizer(x_local)
else:
if isinstance(x_local, QuantizedTensorBase):
x_local = x_local.dequantize(dtype=dtype)
if x_local.dtype != dtype:
x_local = x_local.to(dtype=dtype)
x_local = maybe_dequantize(x_local, dtype)
x = x_local
# Initialize weight tensor
w = weight
w_is_quantized = isinstance(w, QuantizedTensorBase)
if with_quantized_compute and not w_is_quantized:
if not with_quantized_compute:
w = maybe_dequantize(w, dtype)
elif with_quantized_compute and not isinstance(w, QuantizedTensorBase):
weight_quantizer.set_usage(rowwise=True, columnwise=input_requires_grad)
w = weight_quantizer(w)
elif not with_quantized_compute and w_is_quantized:
w = w.dequantize()
if not with_quantized_compute and w.dtype != dtype:
w = w.to(dtype=dtype)
# Construct output tensor if needed
reduce_scatter_output = None
......