Unverified Commit 7ad130ef authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

Offloading support for multiple attention layouts (#2024)



* Added multi-layout support for attention
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* Comment/cleanup
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* Bug fix on import time
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci



* fix
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Signed-off-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-ptyche01.ptyche.clusters.nvidia.com>
Co-authored-by: default avatarpre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: default avatarPawel Gadzinski <pgadzinski@nvidia.com>
Co-authored-by: default avatarPaweł Gadziński <62263673+pggPL@users.noreply.github.com>
parent dd9433e7
...@@ -1258,7 +1258,6 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1258,7 +1258,6 @@ class FusedAttnFunc(torch.autograd.Function):
else: else:
tensor_list = [q, k, v, out] tensor_list = [q, k, v, out]
qkv_layout = "sbhd_sbhd_sbhd"
mark_activation_offload(*tensor_list) mark_activation_offload(*tensor_list)
mark_activation_offload(*aux_ctx_tensors) mark_activation_offload(*aux_ctx_tensors)
...@@ -1293,7 +1292,31 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1293,7 +1292,31 @@ class FusedAttnFunc(torch.autograd.Function):
ctx.attn_scale = attn_scale ctx.attn_scale = attn_scale
ctx.dropout_p = dropout_p ctx.dropout_p = dropout_p
ctx.fast_zero_fill = fast_zero_fill ctx.fast_zero_fill = fast_zero_fill
ctx.qkv_layout = qkv_layout
from transformer_engine.pytorch.cpu_offload import (
CPUOffloadedLayer,
)
# If an interleaved tensor is offloaded, the reloaded tensor will be
# non-interleaved, so the QKV layout needs to be modified
# for the backward pass
if CPUOffloadedLayer and CPUOffloadEnabled:
reload_layout = ""
split_list = qkv_layout.split("_")
for split in split_list:
temp_layout = ""
rep_count = 1
for s in split:
if s.isalpha():
temp_layout = temp_layout + s
else:
rep_count = int(s)
for _ in range(rep_count):
reload_layout = reload_layout + temp_layout + "_"
ctx.qkv_layout = reload_layout[:-1]
else:
ctx.qkv_layout = qkv_layout
ctx.attn_bias_type = attn_bias_type ctx.attn_bias_type = attn_bias_type
ctx.attn_mask_type = attn_mask_type ctx.attn_mask_type = attn_mask_type
ctx.softmax_type = softmax_type ctx.softmax_type = softmax_type
......
...@@ -16,6 +16,7 @@ from .tensor.float8_tensor import Float8Tensor ...@@ -16,6 +16,7 @@ from .tensor.float8_tensor import Float8Tensor
__all__ = ["get_cpu_offload_context"] __all__ = ["get_cpu_offload_context"]
CPUOffloadEnabled = False CPUOffloadEnabled = False
CPUOffloadedLayer = False
def mark_activation_offload(*tensors): def mark_activation_offload(*tensors):
...@@ -353,6 +354,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -353,6 +354,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
self.h2d_stream = torch.cuda.Stream() self.h2d_stream = torch.cuda.Stream()
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
global CPUOffloadedLayer
torch_stray_tensor = isinstance( torch_stray_tensor = isinstance(
tensor, tensor,
...@@ -408,6 +410,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -408,6 +410,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
tensor.clear() tensor.clear()
else: else:
self.tensor_tag_to_buf[tensor_tag] = t self.tensor_tag_to_buf[tensor_tag] = t
# Needed to differentiate this from a non-offloaded layer's attention:
# the QKV layout of a non-offloaded layer's attention needs
# to be left unmodified while reloading
CPUOffloadedLayer = True
else: else:
tensor_tag = (-1, self.torch_tensor_count) tensor_tag = (-1, self.torch_tensor_count)
self.torch_tensor_count += 1 self.torch_tensor_count += 1
...@@ -417,6 +424,8 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -417,6 +424,8 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def tensor_pop(self, tensor_tag, **kwargs): def tensor_pop(self, tensor_tag, **kwargs):
"""Tensor pop.""" """Tensor pop."""
global CPUOffloadedLayer
assert tensor_tag in self.tensor_tag_to_state assert tensor_tag in self.tensor_tag_to_state
tensor = self.tensor_tag_to_state.pop(tensor_tag) tensor = self.tensor_tag_to_state.pop(tensor_tag)
...@@ -480,6 +489,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -480,6 +489,7 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def synchronize_on_group_commit_forward(self, current_group): def synchronize_on_group_commit_forward(self, current_group):
"""Synchronize on group commit forward.""" """Synchronize on group commit forward."""
global CPUOffloadedLayer
# For the first group, kickstart the offload after we have # For the first group, kickstart the offload after we have
# the first compute completion # the first compute completion
...@@ -528,6 +538,9 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -528,6 +538,9 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# Increment the offload group count to keep track # Increment the offload group count to keep track
self.offloaded_group_count += 1 self.offloaded_group_count += 1
if current_group == (self.num_offload_group - 1):
CPUOffloadedLayer = False
if not self.double_buffer_created: if not self.double_buffer_created:
# Creating second copy of double buffer for tensors that are offloaded # Creating second copy of double buffer for tensors that are offloaded
if current_group == (self.num_layers - 1): if current_group == (self.num_layers - 1):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment