Unverified Commit 44574def authored by Selvaraj Anandaraj's avatar Selvaraj Anandaraj Committed by GitHub
Browse files

Fixed offloading for PyT version/ Added Attention activation offloading...


Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support (#632)

* Fixed offloading for PyT version/ Added Attention activation offloading support/ Native FP8 support
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed activation offloading for fused attention
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed the illegal memory access issue for activation offloading of attention
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Removed the version guard
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Pipeline failures fix
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Fixed lint errors
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

* Lint error fix
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>

---------
Signed-off-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
Co-authored-by: default avatarSelvaraj Anandaraj <selvaraja@login-eos01.eos.clusters.nvidia.com>
parent 4077ccc1
...@@ -1752,6 +1752,14 @@ class FlashAttention(torch.nn.Module): ...@@ -1752,6 +1752,14 @@ class FlashAttention(torch.nn.Module):
deterministic=self.deterministic deterministic=self.deterministic
) )
else: else:
from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
tensor_list = [query_layer, key_layer, value_layer, cu_seqlens_q, cu_seqlens_kv]
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True
with self.attention_dropout_ctx(): with self.attention_dropout_ctx():
fa_optional_forward_kwargs = {} fa_optional_forward_kwargs = {}
if _flash_attn_2_3_plus: if _flash_attn_2_3_plus:
...@@ -1938,6 +1946,15 @@ class FusedAttnFunc(torch.autograd.Function): ...@@ -1938,6 +1946,15 @@ class FusedAttnFunc(torch.autograd.Function):
attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type, attn_scale, dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
rng_gen) rng_gen)
from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
tensor_list = [q, k, v, out, cu_seqlens_q, cu_seqlens_kv]
qkv_layout = 'sbhd_sbhd_sbhd'
for tensor in tensor_list:
if tensor is not None:
tensor.activation_offloading = True
ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv) ctx.save_for_backward(q, k, v, out, cu_seqlens_q, cu_seqlens_kv)
ctx.aux_ctx_tensors = aux_ctx_tensors ctx.aux_ctx_tensors = aux_ctx_tensors
ctx.max_seqlen_q = max_seqlen_q ctx.max_seqlen_q = max_seqlen_q
...@@ -2818,6 +2835,13 @@ class DotProductAttention(torch.nn.Module): ...@@ -2818,6 +2835,13 @@ class DotProductAttention(torch.nn.Module):
assert (not context_parallel), \ assert (not context_parallel), \
"Context parallelism is only implemented with Flash Attention and Fused Attention!" "Context parallelism is only implemented with Flash Attention and Fused Attention!"
from .cpu_offload import CPUOffloadEnabled
if CPUOffloadEnabled:
warnings.warn(
"Attention activation Offloading is only implemented"
"with Flash Attention and Fused Attention!"
)
if _NVTE_DEBUG: if _NVTE_DEBUG:
print("[DotProductAttention]: using unfused DPA") print("[DotProductAttention]: using unfused DPA")
if use_unfused_attention: if use_unfused_attention:
......
...@@ -184,6 +184,7 @@ class SynchronizedGroupOffloadHandler(OffloadHandler): ...@@ -184,6 +184,7 @@ class SynchronizedGroupOffloadHandler(OffloadHandler):
# the tensor back to gpu and deletes the cpu tensor. # the tensor back to gpu and deletes the cpu tensor.
# These will increment whenever `group_commit()` is invoked # These will increment whenever `group_commit()` is invoked
self.current_group, self.tensor_count_current_group = (0, 0) self.current_group, self.tensor_count_current_group = (0, 0)
self.torch_tensor_count = 0
self.tensor_tag_to_state = {} self.tensor_tag_to_state = {}
def on_group_commit_forward(self): def on_group_commit_forward(self):
...@@ -310,24 +311,35 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -310,24 +311,35 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any: def tensor_push(self, tensor: torch.Tensor, **kwargs) -> Any:
# obtain a unique tensor tag
tensor_tag = (self.current_group, self.tensor_count_current_group)
self.tensor_count_current_group += 1
assert tensor_tag not in self.tensor_tag_to_state
if (self.current_group < self.num_offload_group torch_stray_tensor = isinstance(tensor,(torch._subclasses.fake_tensor.FakeTensor,
and self.tensor_need_offloading_checker(tensor)): torch._subclasses.functional_tensor.FunctionalTensor))
# first copy the tensor to tensorbuf, so that the original tensor will not be deleted
tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag) if not torch_stray_tensor:
tensor_buf.copy_(tensor) # obtain a unique tensor tag
if hasattr(tensor,"weight_offloading"): tensor_tag = (self.current_group, self.tensor_count_current_group)
tensor_buf.weight_offloading = True self.tensor_count_current_group += 1
if hasattr(tensor,"activation_offloading"): assert tensor_tag not in self.tensor_tag_to_state
tensor_buf.activation_offloading = True
# Here we just save it, and at commit, bulk_offload_group will handle it if (self.current_group < self.num_offload_group
self.tensor_tag_to_state[tensor_tag] = tensor_buf and self.tensor_need_offloading_checker(tensor)):
# first copy the tensor to tensorbuf,
# so that the original tensor will not be deleted
tensor_buf = self.get_tensor_buf_for_offloaded_tensor(tensor, tensor_tag)
tensor_buf.copy_(tensor)
if hasattr(tensor,"weight_offloading"):
tensor_buf.weight_offloading = True
if hasattr(tensor,"activation_offloading"):
tensor_buf.activation_offloading = True
# Here we just save it, and at commit, bulk_offload_group will handle it
self.tensor_tag_to_state[tensor_tag] = tensor_buf
else:
self.tensor_tag_to_state[tensor_tag] = tensor
else: else:
tensor_tag = (-1,self.torch_tensor_count)
self.torch_tensor_count += 1
self.tensor_tag_to_state[tensor_tag] = tensor self.tensor_tag_to_state[tensor_tag] = tensor
return tensor_tag return tensor_tag
def tensor_pop(self, tensor_tag, **kwargs): def tensor_pop(self, tensor_tag, **kwargs):
...@@ -350,6 +362,10 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler): ...@@ -350,6 +362,10 @@ class AsyncDoubleBufferGroupOffloadHandler(SynchronizedGroupOffloadHandler):
# if offload, return the reference to cpu copy # if offload, return the reference to cpu copy
if self.tensor_need_offloading_checker(tensor_on_device): if self.tensor_need_offloading_checker(tensor_on_device):
if hasattr(tensor_on_device,"weight_offloading"):
delattr(tensor_on_device,"weight_offloading")
if hasattr(tensor_on_device,"activation_offloading"):
delattr(tensor_on_device,"activation_offloading")
state = SynchronizedGroupOffloadHandler.offload(tensor_on_device) state = SynchronizedGroupOffloadHandler.offload(tensor_on_device)
self.tensor_tag_to_state[tensor_tag] = state self.tensor_tag_to_state[tensor_tag] = state
......
...@@ -242,7 +242,7 @@ class _LayerNormLinear(torch.autograd.Function): ...@@ -242,7 +242,7 @@ class _LayerNormLinear(torch.autograd.Function):
if cpu_offloading: if cpu_offloading:
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True weight.main_grad.weight_offloading = True
if fp8: if fp8 and weight_t_fp8 is not None:
weight_t_fp8.weight_offloading = True weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True ln_weight.weight_offloading = True
weight.weight_offloading = True weight.weight_offloading = True
......
...@@ -424,8 +424,9 @@ class _LayerNormMLP(torch.autograd.Function): ...@@ -424,8 +424,9 @@ class _LayerNormMLP(torch.autograd.Function):
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
fc1_weight.main_grad.weight_offloading = True fc1_weight.main_grad.weight_offloading = True
fc2_weight.main_grad.weight_offloading = True fc2_weight.main_grad.weight_offloading = True
if fp8: if fp8 and fc1_weight_t_fp8 is not None:
fc1_weight_t_fp8.weight_offloading = True fc1_weight_t_fp8.weight_offloading = True
if fp8 and fc2_weight_t_fp8 is not None:
fc2_weight_t_fp8.weight_offloading = True fc2_weight_t_fp8.weight_offloading = True
ln_weight.weight_offloading = True ln_weight.weight_offloading = True
fc1_weight.weight_offloading = True fc1_weight.weight_offloading = True
......
...@@ -275,7 +275,7 @@ class _Linear(torch.autograd.Function): ...@@ -275,7 +275,7 @@ class _Linear(torch.autograd.Function):
if cpu_offloading: if cpu_offloading:
if fuse_wgrad_accumulation: if fuse_wgrad_accumulation:
weight.main_grad.weight_offloading = True weight.main_grad.weight_offloading = True
if fp8: if fp8 and weight_t_fp8 is not None:
weight_t_fp8.weight_offloading = True weight_t_fp8.weight_offloading = True
weight.weight_offloading = True weight.weight_offloading = True
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment