Unverified commit 5612ba78, authored by Kirthi Shankar Sivamani, committed by GitHub
Browse files

Fix tensor-parallel (TP) case to remove a redundant all-reduce (AR) (#3)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a2caec1e
......@@ -498,6 +498,7 @@ class _LayerNormLinear(torch.autograd.Function):
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
return_layernorm_output: bool,
......@@ -613,6 +614,7 @@ class _LayerNormLinear(torch.autograd.Function):
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
......@@ -621,7 +623,7 @@ class _LayerNormLinear(torch.autograd.Function):
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row":
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
......@@ -710,7 +712,7 @@ class _LayerNormLinear(torch.autograd.Function):
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column":
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if ctx.fp8:
......@@ -768,7 +770,7 @@ class _LayerNormLinear(torch.autograd.Function):
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and handle is not None:
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
......@@ -808,6 +810,7 @@ class _LayerNormLinear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1045,6 +1048,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
self.return_layernorm_output,
......@@ -1087,6 +1091,7 @@ class _Linear(torch.autograd.Function):
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
parallel_mode: Union[str, None],
) -> torch.Tensor:
......@@ -1188,6 +1193,7 @@ class _Linear(torch.autograd.Function):
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.parallel_mode = parallel_mode
ctx.tp_group = tp_group
......@@ -1195,7 +1201,7 @@ class _Linear(torch.autograd.Function):
# Row Parallel Linear
if parallel_mode == "row" and sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif parallel_mode == "row":
elif parallel_mode == "row" and tensor_parallel:
out, _ = allreduce(out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
......@@ -1283,7 +1289,7 @@ class _Linear(torch.autograd.Function):
dgrad, handle = reduce_scatter_along_first_dim(
dgrad, ctx.tp_group, async_op=True
)
elif ctx.parallel_mode == "column":
elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)
if ctx.fp8:
......@@ -1333,7 +1339,7 @@ class _Linear(torch.autograd.Function):
)
# Column Parallel Linear
if ctx.parallel_mode == "column" and handle is not None:
if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
handle.wait()
if not ctx.use_bias:
......@@ -1358,6 +1364,7 @@ class _Linear(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -1563,6 +1570,7 @@ class Linear(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.parallel_mode,
)
......@@ -1604,6 +1612,7 @@ class _LayerNormMLP(torch.autograd.Function):
fuse_wgrad_accumulation: bool,
tp_group: Union[dist_group_type, None],
sequence_parallel: bool,
tensor_parallel: bool,
activation_dtype: torch.dtype,
return_layernorm_output: bool,
bias_gelu_nvfusion: bool,
......@@ -1774,6 +1783,7 @@ class _LayerNormMLP(torch.autograd.Function):
ctx.is_first_microbatch = is_first_microbatch
ctx.use_bias = use_bias
ctx.sequence_parallel = sequence_parallel
ctx.tensor_parallel = tensor_parallel
ctx.inp_shape = inp.shape
ctx.tp_group = tp_group
ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
......@@ -1783,7 +1793,7 @@ class _LayerNormMLP(torch.autograd.Function):
# Row Parallel Linear
if set_parallel_mode and sequence_parallel:
fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
elif set_parallel_mode:
elif set_parallel_mode and tensor_parallel:
fc2_out, _ = allreduce(fc2_out, tp_group)
# [*, in_features] -> [*, out_features] except first dimension changes for SP
......@@ -1981,7 +1991,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_dgrad, handle = reduce_scatter_along_first_dim(
fc1_dgrad, ctx.tp_group, async_op=True
)
elif ctx.set_parallel_mode:
elif ctx.set_parallel_mode and ctx.tensor_parallel:
fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)
if ctx.fp8:
......@@ -2044,7 +2054,7 @@ class _LayerNormMLP(torch.autograd.Function):
fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs
# Column Parallel Linear
if ctx.set_parallel_mode and handle is not None:
if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
handle.wait()
# LayerNorm gradient
......@@ -2089,6 +2099,7 @@ class _LayerNormMLP(torch.autograd.Function):
None,
None,
None,
None,
)
......@@ -2355,6 +2366,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
self.fuse_wgrad_accumulation,
self.tp_group,
self.sequence_parallel,
self.tp_size > 1,
self.activation_dtype,
self.return_layernorm_output,
self.bias_gelu_nvfusion,
......
......@@ -271,10 +271,7 @@ class MultiHeadAttention(torch.nn.Module):
self.tp_size = tp_size
self.sequence_parallel = (tp_size > 1) and sequence_parallel
projection_size = kv_channels * num_attention_heads
self.hidden_size_per_attention_head = divide(
projection_size, num_attention_heads
)
self.hidden_size_per_attention_head = kv_channels
self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)
common_gemm_kwargs = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment