Unverified Commit 8ace813c authored by Kshitij Lakhani, committed by GitHub

Refactor attention.py part 2 (#1704)



* Move MultiHeadAttention into its own file. Modify tests and files in t_e/pytorch to import from the new MHA module
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Restore MHA changes from PR 1614 that were lost during rebase
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Move context parallelism code into its own file. Modify tests and local imports of CP code accordingly
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Move softmax.py from pytorch/ to pytorch/d_p_a
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Move Unfused and Fused attention to backends.py and some utility functions to pytorch/utils.py
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Restore mark_activation_offload changes from PR 1678 that were lost during rebase
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Refactor attention dir
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Refactor dir structure. Make relevant symbols public in __init__ for attention and d_p_a dirs
Move FA package imports to backends.py
Code cleanup
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Modify tests to import attention modules correctly
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Lint fixes
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Code clean up and fix typo
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Allow InferenceParams and RoPE imports from the attention module and the pytorch module
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* Allow InferenceParams and RoPE imports via transformer_engine.pytorch and transformer_engine.pytorch.attention modules
Remove unnecessary checks for check_set_window_size in MHA and TL
Reorder backends so that smaller classes come first and larger ones last
Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
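For illustration, a minimal import sketch assuming the module layout introduced in this diff (the shorter re-exports via transformer_engine.pytorch and transformer_engine.pytorch.attention mentioned above are assumed and not shown here):

    # New locations after the refactor, as they appear in the diff below
    from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
    from transformer_engine.pytorch.attention.inference import InferenceParams
    from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention
    from transformer_engine.pytorch.attention.dot_product_attention.backends import UnfusedDotProductAttention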

* Reinstate changes from PR 1478 for rope.py that were lost during rebase conflict resolution
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Fix lint issues
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* nit: Code clean up
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* Make imports leaner
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



---------
Signed-off-by: Kshitij Janardan Lakhani <klakhani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
parent 6c942ffd
@@ -565,7 +565,9 @@ def has_te_modules(network):
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
from .attention.multi_head_attention import MultiheadAttention
from .transformer import TransformerLayer
te_classes_list = [
@@ -1478,7 +1480,9 @@ def _is_te_module(module):
"""
from .module import LayerNorm, RMSNorm
from .module.base import TransformerEngineBaseModule
from .attention import UnfusedDotProductAttention, DotProductAttention, MultiheadAttention
from .attention.dot_product_attention.dot_product_attention import DotProductAttention
from .attention.dot_product_attention.backends import UnfusedDotProductAttention
from .attention.multi_head_attention import MultiheadAttention
from .transformer import TransformerLayer
te_classes_list = [
......
@@ -536,7 +536,9 @@ def _make_graphed_callables(
# Only Set the FP8 meta for the modules included by forward
continue
fp8_recipe = FP8GlobalStateManager.get_fp8_recipe()
from transformer_engine.pytorch.attention import DotProductAttention
from transformer_engine.pytorch.attention.dot_product_attention import (
DotProductAttention,
)
if (
isinstance(m, DotProductAttention)
......
@@ -12,11 +12,8 @@ import torch
from transformer_engine.pytorch.module import LayerNormMLP, LayerNorm, RMSNorm
from transformer_engine.debug.pytorch.debug_state import TEDebugState
from transformer_engine.pytorch.attention import (
MultiheadAttention,
)
from transformer_engine.pytorch.dot_product_attention.inference import InferenceParams
from transformer_engine.pytorch.dot_product_attention.utils import check_set_window_size
from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention
from transformer_engine.pytorch.attention.inference import InferenceParams
from transformer_engine.pytorch.jit import (
set_jit_fusion_options,
warmup_jit_bias_dropout_add_all_dtypes,
@@ -286,11 +283,9 @@ class TransformerLayer(torch.nn.Module):
super().__init__()
self.self_attn_mask_type = self_attn_mask_type
self.window_size = check_set_window_size(self_attn_mask_type, window_size)
self.window_size = window_size
self.enc_dec_attn_mask_type = enc_dec_attn_mask_type
self.enc_dec_window_size = check_set_window_size(
enc_dec_attn_mask_type, enc_dec_window_size
)
self.enc_dec_window_size = enc_dec_window_size
params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad
ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad
@@ -657,12 +652,10 @@ class TransformerLayer(torch.nn.Module):
self_attn_mask_type = self.self_attn_mask_type
if window_size is None:
window_size = self.window_size
window_size = check_set_window_size(self_attn_mask_type, window_size)
if enc_dec_attn_mask_type is None:
enc_dec_attn_mask_type = self.enc_dec_attn_mask_type
if enc_dec_window_size is None:
enc_dec_window_size = self.enc_dec_window_size
enc_dec_window_size = check_set_window_size(enc_dec_attn_mask_type, enc_dec_window_size)
assert (
self_attn_mask_type in AttnMaskTypes
......
@@ -7,7 +7,8 @@ from __future__ import annotations
import functools
import math
import os
from typing import Any, Callable, List, Optional, Tuple
from typing import Any, Callable, List, Optional, Tuple, Union
import numpy as np
import torch
import transformer_engine.pytorch.cpp_extensions as ext
@@ -155,6 +156,184 @@ def split_tensor_along_dim(
return tensor_list
# @klakhani TODO: Consider combining with split_tensor_along_dim(), no_op_cat(), and SplitAlongDim
def combine_tensors(
tensors: List[torch.Tensor],
dim: int,
) -> torch.Tensor:
"""Combine tensors along a particular dimension"""
num_tensors = len(tensors)
new_shape = list(tensors[0].shape)
new_shape.insert(dim, num_tensors)
from transformer_engine.pytorch.float8_tensor import Float8Tensor
if isinstance(tensors[0], Float8Tensor):
new_stride = list(tensors[0]._data.stride())
new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0]._data.dtype)
combined_tensor.set_(
tensors[0]._data.untyped_storage(),
tensors[0]._data.storage_offset(),
new_shape,
new_stride,
)
combined_tensor = Float8Tensor.make_like(tensors[0], data=combined_tensor, shape=new_shape)
else:
new_stride = list(tensors[0].stride())
new_stride.insert(dim, int(new_stride[dim - 1] / num_tensors))
combined_tensor = torch.Tensor().to(device=tensors[0].device, dtype=tensors[0].dtype)
combined_tensor.set_(
tensors[0].untyped_storage(), tensors[0].storage_offset(), new_shape, new_stride
)
return combined_tensor
class SplitAlongDim(torch.autograd.Function):
"""
Split tensor along given dimension
"""
@staticmethod
def forward(
ctx,
mixed_x_layer: torch.Tensor,
split_dim: int,
split_size_or_sections: Union[int, List[int], Tuple[int]],
squeeze=False,
) -> Tuple[torch.Tensor, ...]:
# pylint: disable=missing-function-docstring
ctx.split_dim = split_dim
ctx.split_size_or_sections = split_size_or_sections
from transformer_engine.pytorch.float8_tensor import Float8Tensor
from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase
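        # FP8 inputs are split on the underlying byte buffer (_data) and each shard is
        # re-wrapped with the parent tensor's FP8 scale/dtype metadata; plain tensors
        # fall through to the regular torch.split path below.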
if isinstance(mixed_x_layer, Float8TensorBase) and not isinstance(
mixed_x_layer, Float8Tensor
):
return tuple(
Float8TensorBase(
fp8_scale_inv=mixed_x_layer._scale_inv,
fp8_dtype=mixed_x_layer._fp8_dtype,
data=x.squeeze(split_dim) if squeeze else x,
shape=x.squeeze(split_dim).shape if squeeze else x.shape,
quantizer=mixed_x_layer._quantizer,
)
for x in torch.split(
mixed_x_layer._data,
split_size_or_sections=split_size_or_sections,
dim=split_dim,
)
)
if isinstance(mixed_x_layer, Float8Tensor):
return tuple(
Float8Tensor.make_like(
mixed_x_layer,
data=x.squeeze(split_dim) if squeeze else x,
shape=x.squeeze(split_dim).shape if squeeze else x.shape,
)
for x in torch.split(
mixed_x_layer._data,
split_size_or_sections=split_size_or_sections,
dim=split_dim,
)
)
out_list = torch.split(mixed_x_layer, split_size_or_sections, dim=split_dim)
if squeeze:
out_list = [x.squeeze(split_dim) for x in out_list]
return out_list
@staticmethod
def backward(ctx, *grad_outputs):
# pylint: disable=missing-function-docstring
assert len(grad_outputs) > 0, "No gradients received for backprop!"
if isinstance(ctx.split_size_or_sections, (list, tuple)):
split_sizes = ctx.split_size_or_sections
assert len(grad_outputs) == len(
split_sizes
), "Unequal number of gradients vs split sections for backprop!"
if isinstance(ctx.split_size_or_sections, int):
split_sizes = [ctx.split_size_or_sections] * len(grad_outputs)
dims = len(grad_outputs[0].shape)
split_dim = (ctx.split_dim + dims) % dims
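        # Fast path: when every incoming gradient is a contiguous slice of the same
        # underlying buffer (same storage pointer, strides, and expected offsets), the
        # full gradient is rebuilt as a zero-copy view; otherwise fall back to torch.cat.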
from transformer_engine.pytorch.float8_tensor import Float8Tensor
if isinstance(grad_outputs[0], Float8Tensor):
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0]._data.untyped_storage().data_ptr()
shape = list(grad_outputs[0].shape)
for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
if (
tensor.stride() != strides
or list(tensor.shape) != shape_i
or tensor._data.untyped_storage().data_ptr() != data_ptr
or tensor.storage_offset() != offset_size
):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(
device=grad_outputs[0].device, dtype=grad_outputs[0]._data.dtype
)
new_shape = list(shape)
new_shape[split_dim] = sum(split_sizes)
ret.set_(
grad_outputs[0]._data.untyped_storage(),
grad_outputs[0]._data.storage_offset(),
new_shape,
strides,
)
return (
Float8Tensor.make_like(grad_outputs[0], data=ret, shape=ret.shape),
None,
None,
)
grad_outputs_data = [x._data for x in grad_outputs]
data = torch.cat(grad_outputs_data, dim=split_dim)
return (
Float8Tensor.make_like(grad_outputs[0], data=data, shape=data.shape),
None,
None,
None,
)
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].untyped_storage().data_ptr()
shape = list(grad_outputs[0].shape)
for i, tensor in enumerate(grad_outputs):
shape_i = shape
shape_i[split_dim] = split_sizes[i]
offset_size = sum(split_sizes[:i]) * np.prod(shape[split_dim + 1 :])
if (
tensor.stride() != strides
or list(tensor.shape) != shape_i
or tensor.untyped_storage().data_ptr() != data_ptr
or tensor.storage_offset() != offset_size
):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device, dtype=grad_outputs[0].dtype)
new_shape = list(shape)
new_shape[split_dim] = sum(split_sizes)
ret.set_(
grad_outputs[0].untyped_storage(),
grad_outputs[0].storage_offset(),
new_shape,
strides,
)
return ret, None, None
return torch.cat(grad_outputs, dim=split_dim), None, None
def validate_ctx_manager(ctx: Callable) -> None:
"""Checks if passed in object can be used as a context manager."""
try:
......