"projects/git@developer.sourcefind.cn:tsoc/hg-misc-tools.git" did not exist on "abad43bff04154de8ebaab6fd305b684a02e2e5c"
Unverified commit 44574def authored by Selvaraj Anandaraj, committed by GitHub

Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support (#632)

* Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed activation offloading for fused attention
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed the illegal memory access issue for activation offloading of attention
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed the version guard
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Pipeline failures fix
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed lint errors
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Lint error fix
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

---------
Signed-off-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: Selvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
parent 4077ccc1
@@ -1752,6 +1752,14 @@ class FlashAttention(torch.nn.Module):
                    deterministic=self.deterministic
                )
        else:
            from .cpu_offload import CPUOffloadEnabled
            if CPUOffloadEnabled:
                tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
                for tensor in tensor_list:
                    if tensor is not None:
                        tensor.activation_offloading = True

            with self.attention_dropout_ctx():
                fa_optional_forward_kwargs = {}
                if _flash_attn_2_3_plus:
@@ -1938,6 +1946,15 @@ class FusedAttnFunc(torch.autograd.Function):
            attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
            rng_gen)

        from .cpu_offload import CPUOffloadEnabled
        if CPUOffloadEnabled:
            tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv]
            qkv_layout = 'sbhd_sbhd_sbhd'
            for tensor in tensor_list:
                if tensor is not None:
                    tensor.activation_offloading = True

        ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
        ctx.aux_ctx_tensors = aux_ctx_tensors
        ctx.max_seqlen_q = max_seqlen_q
@@ -2818,6 +2835,13 @@ class DotProductAttention(torch.nn.Module):
            assert (not context_parallel), \
                "Context parallelism is only implemented with Flash Attention and Fused Attention!"

            from .cpu_offload import CPUOffloadEnabled
            if CPUOffloadEnabled:
                warnings.warn(
                    "Attention activation offloading is only implemented "
                    "with Flash Attention and Fused Attention!"
                )

            if _NVTE_DEBUG:
                print("[DotProductAttention]: using unfused DPA")
            if use_unfused_attention:
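
The three attention hunks above share one mechanism: before the attention inputs/outputs are saved for backward, every tensor that should leave GPU memory is tagged with an `activation_offloading` attribute that the offload handler later checks. Below is a minimal, self-contained sketch of that tagging pattern; `CPUOffloadEnabled` here is a stand-in for the module-level flag the diff imports from `.cpu_offload`, and `mark_activations_for_offload` is a hypothetical helper, not a function in this commit.

# Illustrative sketch only (not part of the commit).
import torch

# Stand-in for the module-level flag imported from .cpu_offload in the diff.
CPUOffloadEnabled = True

def mark_activations_for_offload(*tensors):
    """Tag tensors so an offloading checker can pick them up later.

    Mirrors the loops added to FlashAttention / FusedAttnFunc: skip None
    entries (e.g. absent cu_seqlens) and set `activation_offloading` on
    the rest. PyTorch tensors accept arbitrary Python attributes.
    """
    if not CPUOffloadEnabled:
        return
    for t in tensors:
        if t is not None:
            t.activation_offloading = True

# Example: tag the tensors an attention forward would save for backward.
q, k, v = (torch.randn(2, 4, 8, 16) for _ in range(3))
cu_seqlens_q = None  # packed-sequence metadata may legitimately be absent
mark_activations_for_offload(q, k, v, cu_seqlens_q)
assert getattr(q, "activation_offloading", False)
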
@@ -184,6 +184,7 @@ class SynchronizedGroupOffloadHandler(OffloadHandler):
        # the tensor back to gpu and deletes the cpu tensor.
        # These will increment whenever `group_commit()` is invoked
        self.current_group, self.tensor_count_current_group = (0, 0)
        self.torch_tensor_count = 0
        self.tensor_tag_to_state = {}

    def on_group_commit_forward(self):
@@ -310,6 +311,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
    def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
        torch_stray_tensor = isinstance(tensor, (torch._subclasses.fake_tensor.FakeTensor,
                                                 torch._subclasses.functional_tensor.FunctionalTensor))
        if not torch_stray_tensor:
            # obtain a unique tensor tag
            tensor_tag = (self.current_group, self.tensor_count_current_group)
            self.tensor_count_current_group += 1
@@ -317,7 +323,8 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
            if (self.current_group < self.num_offload_group
                    and self.tensor_need_offloading_checker(tensor)):
                # first copy the tensor to tensorbuf, so that the original tensor will not be deleted
                # first copy the tensor to tensorbuf,
                # so that the original tensor will not be deleted
                tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
                tensor_buf.copy_(tensor)
                if hasattr(tensor, "weight_offloading"):
@@ -328,6 +335,11 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
                self.tensor_tag_to_state[tensor_tag] = tensor_buf
            else:
                self.tensor_tag_to_state[tensor_tag] = tensor
        else:
            tensor_tag = (-1, self.torch_tensor_count)
            self.torch_tensor_count += 1
            self.tensor_tag_to_state[tensor_tag] = tensor

        return tensor_tag

    def tensor_pop(self, tensor_tag, **kwargs):
@@ -350,6 +362,10 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
        # if offload, return the reference to cpu copy
        if self.tensor_need_offloading_checker(tensor_on_device):
            if hasattr(tensor_on_device, "weight_offloading"):
                delattr(tensor_on_device, "weight_offloading")
            if hasattr(tensor_on_device, "activation_offloading"):
                delattr(tensor_on_device, "activation_offloading")
            state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
            self.tensor_tag_to_state[tensor_tag] = state
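
The handler changes boil down to tag bookkeeping: ordinary tensors get a (group, index) tag, fake/functional "stray" tensors produced by newer PyTorch tracing get a (-1, n) tag and stay on device, and the offloading markers are stripped before a tensor is actually moved to host memory. A simplified, self-contained sketch of that bookkeeping follows; the class name and structure are illustrative, and the real handler's pinned buffers and CUDA streams are omitted.

# Illustrative sketch only (not the real AsyncDoubleBufferGroupOffloadHandler).
# Requires a PyTorch recent enough to expose these tensor subclasses.
import torch
from typing import Any, Dict, Tuple
from torch._subclasses.fake_tensor import FakeTensor
from torch._subclasses.functional_tensor import FunctionalTensor

class TinyOffloadHandler:
    def __init__(self):
        self.current_group = 0
        self.tensor_count_current_group = 0
        self.torch_tensor_count = 0  # counter added in this commit for stray tensors
        self.tensor_tag_to_state: Dict[Tuple[int, int], Any] = {}

    def tensor_push(self, tensor: torch.Tensor) -> Tuple[int, int]:
        stray = isinstance(tensor, (FakeTensor, FunctionalTensor))
        if not stray:
            tag = (self.current_group, self.tensor_count_current_group)
            self.tensor_count_current_group += 1
        else:
            # fake/functional tensors bypass offloading and get a (-1, n) tag
            tag = (-1, self.torch_tensor_count)
            self.torch_tensor_count += 1
        self.tensor_tag_to_state[tag] = tensor
        return tag

    def tensor_pop(self, tag: Tuple[int, int]) -> torch.Tensor:
        tensor = self.tensor_tag_to_state.pop(tag)
        # strip the markers, as the real handler does before offloading to host
        for attr in ("weight_offloading", "activation_offloading"):
            if hasattr(tensor, attr):
                delattr(tensor, attr)
        return tensor

handler = TinyOffloadHandler()
t = torch.ones(4)
t.activation_offloading = True
tag = handler.tensor_push(t)  # -> (0, 0) for a regular tensor
assert not hasattr(handler.tensor_pop(tag), "activation_offloading")
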
@@ -242,7 +242,7 @@ class _LayerNormLinear(torch.autograd.Function):
        if cpu_offloading:
            if fuse_wgrad_accumulation:
                weight.main_grad.weight_offloading = True
            if fp8:
            if fp8 and weight_t_fp8 is not None:
                weight_t_fp8.weight_offloading = True
            ln_weight.weight_offloading = True
            weight.weight_offloading = True
@@ -424,8 +424,9 @@ class _LayerNormMLP(torch.autograd.Function):
            if fuse_wgrad_accumulation:
                fc1_weight.main_grad.weight_offloading = True
                fc2_weight.main_grad.weight_offloading = True
            if fp8:
            if fp8 and fc1_weight_t_fp8 is not None:
                fc1_weight_t_fp8.weight_offloading = True
            if fp8 and fc2_weight_t_fp8 is not None:
                fc2_weight_t_fp8.weight_offloading = True
            ln_weight.weight_offloading = True
            fc1_weight.weight_offloading = True
@@ -275,7 +275,7 @@ class _Linear(torch.autograd.Function):
        if cpu_offloading:
            if fuse_wgrad_accumulation:
                weight.main_grad.weight_offloading = True
            if fp8:
            if fp8 and weight_t_fp8 is not None:
                weight_t_fp8.weight_offloading = True
            weight.weight_offloading = True
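
The last three hunks apply the same fix: with native FP8 support, the transposed FP8 weight cache can legitimately be None, so it must not be tagged for offloading unconditionally. A minimal sketch of the guarded pattern is below; the helper name and stand-alone usage are illustrative, not the module's API.

# Illustrative sketch only (not part of the commit).
import torch

def tag_weight_for_offload(weight, weight_t_fp8=None, fp8=False):
    # Mirrors the fixed tagging in _Linear / _LayerNormLinear / _LayerNormMLP:
    # only touch the FP8 transpose cache when it actually exists, since
    # native FP8 execution can leave it as None.
    if fp8 and weight_t_fp8 is not None:
        weight_t_fp8.weight_offloading = True
    weight.weight_offloading = True

w = torch.nn.Parameter(torch.randn(8, 8))
tag_weight_for_offload(w, weight_t_fp8=None, fp8=True)  # no AttributeError on a missing cache
assert getattr(w, "weight_offloading", False)
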