Unverified Commit 7f1d604f authored by cyanguwa, committed by GitHub

[PyTorch] Fix tp_group_initialized error (#819)



Remove tp_size/tp_group, as amax reduction is handled by fp8_group().
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 38c01c8b
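
For context: under FP8 execution, the amax reduction runs over the distributed group supplied to fp8_autocast() (retrieved internally through fp8_group()), so the fused-attention path never needs its own tp_size/tp_group copies. Below is a minimal sketch of that setup, assuming an initialized NCCL process group; model, inp, and amax_group are placeholder names for illustration, not part of this diff:

    # Minimal sketch (not part of this diff): the group used for amax
    # reduction is whatever is passed to fp8_autocast() as fp8_group.
    import torch.distributed as dist
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    dist.init_process_group(backend="nccl")  # hypothetical setup
    amax_group = dist.new_group()            # group used for amax reduction

    recipe = DelayedScaling()                # delayed-scaling FP8 recipe
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=amax_group):
        out = model(inp)                     # placeholders for illustration
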
@@ -2153,7 +2153,7 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen, cu_seqlens, qkv, qkv_dtype, attn_bias, attn_scale,
                 dropout_p, fast_zero_fill, qkv_layout, attn_bias_type, attn_mask_type,
                 rng_gen, fused_attention_backend, use_FAv2_bwd,
-                fp8, fp8_meta, tp_size, tp_group):
+                fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2227,8 +2227,6 @@ class FusedAttnFunc_qkvpacked(torch.autograd.Function):
         qkvo_tensors = (qkv, out_save) if not ctx.fp8 else (None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen = max_seqlen
         ctx.qkv_dtype = qkv_dtype
@@ -2349,7 +2347,7 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                 q, kv, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                 qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
-                use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
+                use_FAv2_bwd, fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2430,8 +2428,6 @@ class FusedAttnFunc_kvpacked(torch.autograd.Function):
         qkvo_tensors = (q, kv, out_save) if not ctx.fp8 else (None, None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -2566,7 +2562,7 @@ class FusedAttnFunc(torch.autograd.Function):
     def forward(ctx, is_training, max_seqlen_q, max_seqlen_kv, cu_seqlens_q, cu_seqlens_kv,
                 q, k, v, qkv_dtype, attn_bias, attn_scale, dropout_p, fast_zero_fill,
                 qkv_layout, attn_bias_type, attn_mask_type, rng_gen, fused_attention_backend,
-                use_FAv2_bwd, fp8, fp8_meta, tp_size, tp_group):
+                use_FAv2_bwd, fp8, fp8_meta):
         if fp8:
             if _NVTE_DEBUG:
                 print('[DotProductAttention]: using FP8 forward')
@@ -2704,8 +2700,6 @@ class FusedAttnFunc(torch.autograd.Function):
         qkvo_tensors = (q, k, v, out_save) if not ctx.fp8 else (None, None, None, None)
         ctx.save_for_backward(*qkvo_tensors, cu_seqlens_q, cu_seqlens_kv, *fp8_tensors)
         ctx.fp8_meta = fp8_meta
-        ctx.tp_size = tp_size
-        ctx.tp_group = tp_group
         ctx.aux_ctx_tensors = aux_ctx_tensors
         ctx.max_seqlen_q = max_seqlen_q
         ctx.max_seqlen_kv = max_seqlen_kv
@@ -2907,8 +2901,6 @@ class FusedAttention(TransformerEngineBaseModule):
         attention_type: str = "self",
         layer_number: Optional[int] = None,
         deterministic: bool = False,
-        tp_size: int = 1,
-        tp_group: Optional[dist_group_type] = None,
     ) -> None:
         super().__init__()
 
@@ -2935,9 +2927,6 @@ class FusedAttention(TransformerEngineBaseModule):
         if os.environ["NVTE_FUSED_ATTN_FORCE_WORKSPACE_OPT"] == "1":
             os.environ["CUDNN_FRONTEND_ATTN_DP_WORKSPACE_LIMIT"] = "-1"
 
-        self.tp_size = tp_size
-        self.tp_group = tp_group
-
     def get_fp8_weights_scratchpad(
         self,
         is_first_microbatch: Union[bool, None],
@@ -3092,8 +3081,6 @@ class FusedAttention(TransformerEngineBaseModule):
                 use_FAv2_bwd,
                 self.fp8 and self.fp8_meta["recipe"].fp8_dpa,
                 self.fp8_meta,
-                self.tp_size,
-                self.tp_group,
             )
 
         # ...hd -> ...(hd)
@@ -3292,9 +3279,7 @@ class DotProductAttention(torch.nn.Module):
                 attention_type=attention_type,
                 layer_number=layer_number,
                 deterministic=self.deterministic,
-                **attn_kwargs,
-                tp_size=self.tp_size,
-                tp_group=self.tp_group)
+                **attn_kwargs)
 
             self.unfused_attention = UnfusedDotProductAttention(
                 norm_factor, **attn_kwargs, layer_number=layer_number)
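
After this change, DotProductAttention builds its FusedAttention submodule from attn_kwargs alone; tp_size/tp_group remain attributes of DotProductAttention itself but are no longer threaded into FusedAttention or its autograd functions. A hedged usage sketch follows, with made-up head and channel sizes:

    # Hedged sketch: constructing the module after this commit.
    # The sizes are arbitrary; only the constructor shape matters here.
    import transformer_engine.pytorch as te

    dpa = te.DotProductAttention(
        num_attention_heads=16,  # hypothetical
        kv_channels=64,          # hypothetical
    )
    # Internally, self.fused_attention = FusedAttention(..., **attn_kwargs)
    # is now created without tp_size/tp_group.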