"include/vscode:/vscode.git/clone" did not exist on "44a8cf19bc84f996cf16b49b24b44333a73cdd43"
Unverified Commit 5612ba78 authored by Kirthi Shankar Sivamani, committed by GitHub

fix TP case to rm redundant AR (#3)


Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent a2caec1e
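The common pattern across the hunks below: each autograd function gains a tensor_parallel argument (passed from the modules as self.tp_size > 1), and the tensor-parallel collectives are issued only when that flag is set, so a single-rank TP group no longer pays for an all-reduce that is a mathematical no-op. A minimal standalone sketch of the gating, using a hypothetical maybe_allreduce helper rather than the module code itself:

import torch
import torch.distributed as dist

def maybe_allreduce(out: torch.Tensor, tp_group, tensor_parallel: bool) -> torch.Tensor:
    # tensor_parallel is expected to be `tp_size > 1`; with a single rank the
    # all-reduce only copies the tensor onto itself, so skip the collective.
    if tensor_parallel:
        dist.all_reduce(out, group=tp_group)
    return out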
@@ -498,6 +498,7 @@ class _LayerNormLinear(torch.autograd.Function):
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
         sequence_parallel: bool,
+        tensor_parallel: bool,
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
         return_layernorm_output: bool,
@@ -613,6 +614,7 @@ class _LayerNormLinear(torch.autograd.Function):
         ctx.is_first_microbatch = is_first_microbatch
         ctx.use_bias = use_bias
         ctx.sequence_parallel = sequence_parallel
+        ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
@@ -621,7 +623,7 @@ class _LayerNormLinear(torch.autograd.Function):
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
             out, _ = reduce_scatter_along_first_dim(out, tp_group)
-        elif parallel_mode == "row":
+        elif parallel_mode == "row" and tensor_parallel:
             out, _ = allreduce(out, tp_group)

         # [*, in_features] -> [*, out_features] except first dimension changes for SP
@@ -710,7 +712,7 @@ class _LayerNormLinear(torch.autograd.Function):
             dgrad, handle = reduce_scatter_along_first_dim(
                 dgrad, ctx.tp_group, async_op=True
             )
-        elif ctx.parallel_mode == "column":
+        elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
             dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

         if ctx.fp8:
@@ -768,7 +770,7 @@ class _LayerNormLinear(torch.autograd.Function):
             )

         # Column Parallel Linear
-        if ctx.parallel_mode == "column" and handle is not None:
+        if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
             handle.wait()

         # LayerNorm gradient
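In the backward hunks the dgrad collective is launched with async_op=True so it can overlap with the weight-gradient GEMM, and the later handle.wait() is now also gated on ctx.tensor_parallel, since no handle exists when the collective was skipped. A rough sketch of that overlap pattern, with wgrad_gemm standing in as a hypothetical callable:

import torch
import torch.distributed as dist

def overlap_dgrad_allreduce(dgrad, wgrad_gemm, tp_group, tensor_parallel: bool):
    # Launch the dgrad all-reduce asynchronously only when TP is real (tp_size > 1).
    handle = dist.all_reduce(dgrad, group=tp_group, async_op=True) if tensor_parallel else None
    wgrad = wgrad_gemm()  # the weight-gradient GEMM runs while the all-reduce is in flight
    if handle is not None:
        handle.wait()  # wait only if a collective was actually issued
    return dgrad, wgrad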
@@ -808,6 +810,7 @@ class _LayerNormLinear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -1045,6 +1048,7 @@ class LayerNormLinear(TransformerEngineBaseModule):
             self.fuse_wgrad_accumulation,
             self.tp_group,
             self.sequence_parallel,
+            self.tp_size > 1,
             self.activation_dtype,
             self.parallel_mode,
             self.return_layernorm_output,
@@ -1087,6 +1091,7 @@ class _Linear(torch.autograd.Function):
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
         sequence_parallel: bool,
+        tensor_parallel: bool,
         activation_dtype: torch.dtype,
         parallel_mode: Union[str, None],
     ) -> torch.Tensor:
@@ -1188,6 +1193,7 @@ class _Linear(torch.autograd.Function):
         ctx.is_first_microbatch = is_first_microbatch
         ctx.use_bias = use_bias
         ctx.sequence_parallel = sequence_parallel
+        ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp.shape
         ctx.parallel_mode = parallel_mode
         ctx.tp_group = tp_group
@@ -1195,7 +1201,7 @@ class _Linear(torch.autograd.Function):
         # Row Parallel Linear
         if parallel_mode == "row" and sequence_parallel:
             out, _ = reduce_scatter_along_first_dim(out, tp_group)
-        elif parallel_mode == "row":
+        elif parallel_mode == "row" and tensor_parallel:
             out, _ = allreduce(out, tp_group)

         # [*, in_features] -> [*, out_features] except first dimension changes for SP
@@ -1283,7 +1289,7 @@ class _Linear(torch.autograd.Function):
             dgrad, handle = reduce_scatter_along_first_dim(
                 dgrad, ctx.tp_group, async_op=True
             )
-        elif ctx.parallel_mode == "column":
+        elif ctx.parallel_mode == "column" and ctx.tensor_parallel:
             dgrad, handle = allreduce(dgrad, ctx.tp_group, async_op=True)

         if ctx.fp8:
@@ -1333,7 +1339,7 @@ class _Linear(torch.autograd.Function):
             )

         # Column Parallel Linear
-        if ctx.parallel_mode == "column" and handle is not None:
+        if ctx.parallel_mode == "column" and ctx.tensor_parallel and handle is not None:
             handle.wait()

         if not ctx.use_bias:
@@ -1358,6 +1364,7 @@ class _Linear(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -1563,6 +1570,7 @@ class Linear(TransformerEngineBaseModule):
             self.fuse_wgrad_accumulation,
             self.tp_group,
             self.sequence_parallel,
+            self.tp_size > 1,
             self.activation_dtype,
             self.parallel_mode,
         )
@@ -1604,6 +1612,7 @@ class _LayerNormMLP(torch.autograd.Function):
         fuse_wgrad_accumulation: bool,
         tp_group: Union[dist_group_type, None],
         sequence_parallel: bool,
+        tensor_parallel: bool,
         activation_dtype: torch.dtype,
         return_layernorm_output: bool,
         bias_gelu_nvfusion: bool,
@@ -1774,6 +1783,7 @@ class _LayerNormMLP(torch.autograd.Function):
         ctx.is_first_microbatch = is_first_microbatch
         ctx.use_bias = use_bias
         ctx.sequence_parallel = sequence_parallel
+        ctx.tensor_parallel = tensor_parallel
         ctx.inp_shape = inp.shape
         ctx.tp_group = tp_group
         ctx.bias_gelu_nvfusion = bias_gelu_nvfusion
@@ -1783,7 +1793,7 @@ class _LayerNormMLP(torch.autograd.Function):
         # Row Parallel Linear
         if set_parallel_mode and sequence_parallel:
             fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group)
-        elif set_parallel_mode:
+        elif set_parallel_mode and tensor_parallel:
             fc2_out, _ = allreduce(fc2_out, tp_group)

         # [*, in_features] -> [*, out_features] except first dimension changes for SP
@@ -1981,7 +1991,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_dgrad, handle = reduce_scatter_along_first_dim(
                 fc1_dgrad, ctx.tp_group, async_op=True
             )
-        elif ctx.set_parallel_mode:
+        elif ctx.set_parallel_mode and ctx.tensor_parallel:
             fc1_dgrad, handle = allreduce(fc1_dgrad, ctx.tp_group, async_op=True)

         if ctx.fp8:
@@ -2044,7 +2054,7 @@ class _LayerNormMLP(torch.autograd.Function):
             fc1_wgrad, fc1_bias_grad, _ = fc1_wgrad_outputs

         # Column Parallel Linear
-        if ctx.set_parallel_mode and handle is not None:
+        if ctx.set_parallel_mode and ctx.tensor_parallel and handle is not None:
             handle.wait()

         # LayerNorm gradient
@@ -2089,6 +2099,7 @@ class _LayerNormMLP(torch.autograd.Function):
             None,
             None,
             None,
+            None,
         )
@@ -2355,6 +2366,7 @@ class LayerNormMLP(TransformerEngineBaseModule):
             self.fuse_wgrad_accumulation,
             self.tp_group,
             self.sequence_parallel,
+            self.tp_size > 1,
             self.activation_dtype,
             self.return_layernorm_output,
             self.bias_gelu_nvfusion,
...
@@ -271,10 +271,7 @@ class MultiHeadAttention(torch.nn.Module):
         self.tp_size = tp_size
         self.sequence_parallel = (tp_size > 1) and sequence_parallel

-        projection_size = kv_channels * num_attention_heads
-        self.hidden_size_per_attention_head = divide(
-            projection_size, num_attention_heads
-        )
+        self.hidden_size_per_attention_head = kv_channels
         self.num_attention_heads_per_partition = divide(num_attention_heads, tp_size)

         common_gemm_kwargs = {
...
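The MultiHeadAttention hunk is a separate simplification: the removed lines computed divide(kv_channels * num_attention_heads, num_attention_heads), which is just kv_channels, so the per-head hidden size is now assigned directly. A quick sanity check of that algebra, assuming divide performs exact integer division:

kv_channels, num_attention_heads = 64, 16
projection_size = kv_channels * num_attention_heads
assert projection_size // num_attention_heads == kv_channels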