Commit 356eb36a authored by Vijay Korthikanti

address review comments

parent aae72ffc
@@ -302,7 +302,7 @@ def parse_args(extra_args_provider=None, defaults={},
             'selective checkpointing granularity'
     # disable async_tensor_model_parallel_allreduce when
-    # model parallel memory optmization is enabled
+    # model parallel memory optimization is enabled
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
@@ -489,7 +489,7 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-granularity', type=str, default=None,
                        choices=['full', 'selective'],
-                       help='Checkpoint activatins to allow for training '
+                       help='Checkpoint activations to allow for training '
                        'with larger models, sequences, and batch sizes. '
                        'It is supported at two granularities 1) full: '
                        'whole transformer layer is checkpointed, '
@@ -567,7 +567,7 @@ def _add_training_args(parser):
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
     group.add_argument('--sequence-parallel', action='store_true',
-                       help='Enable sequence parallel optmization.')
+                       help='Enable sequence parallel optimization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
......
@@ -220,11 +220,9 @@ class Embedding(MegatronModule):
         if self.fp32_residual_connection:
             embeddings = embeddings.float()
-        if self.sequence_parallel:
-            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
         # Dropout.
         if self.sequence_parallel:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
             with mpu.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
......
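The hunk above moves the scatter into the sequence-parallel branch, so dropout runs on the already-scattered sub-sequence under the forked RNG state. A minimal standalone sketch of the resulting control flow, written outside the diff; the scatter function, RNG tracker, and dropout module are passed in as stand-ins rather than taken from the Megatron API:

import torch

def embedding_dropout_sketch(embeddings, dropout, sequence_parallel,
                             scatter_to_sequence_parallel_region, rng_tracker):
    if sequence_parallel:
        # Split the sequence dimension across tensor-parallel ranks first ...
        embeddings = scatter_to_sequence_parallel_region(embeddings)
        # ... then apply dropout under a rank-specific RNG state, so each
        # rank drops different elements of its own sub-sequence.
        with rng_tracker.fork():
            embeddings = dropout(embeddings)
    else:
        embeddings = dropout(embeddings)
    return embeddings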
@@ -130,21 +130,21 @@ class SwitchMLP(MegatronModule):
             self.experts.append(ParallelMLP(init_method, output_layer_init_method))
 
     def forward(self, hidden_states):
-        # hidden_states: [b, s, h]
-        b = hidden_states.size(0)
-        s = hidden_states.size(1)
+        # hidden_states: [s, b, h]
+        s = hidden_states.size(0)
+        b = hidden_states.size(1)
         h = hidden_states.size(2)
         route = self.router(hidden_states)
         route = torch.nn.functional.softmax(route, dim=2)
         max_prob, max_ind = torch.max(route, dim=2)
-        max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
+        max_prob = torch.unsqueeze(max_prob, 2) # [s b 1]
 
         # TODO (rprenger) TODO this could be made easier to read
-        # Converting [b, s, h] to [b*s, h].
+        # Converting [s, b, h] to [s*b, h].
         # Each vector could be routed differently
-        hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
-        max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
-        max_ind = max_ind.view(-1) # [b*s]
+        hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
+        max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1]
+        max_ind = max_ind.view(-1) # [s*b]
 
         output_total = torch.empty_like(hidden_states)
         output_bias_total = torch.empty_like(hidden_states)
@@ -160,14 +160,14 @@ class SwitchMLP(MegatronModule):
         output_total = output_total*max_prob
         output_bias_total = output_bias_total*max_prob
-        output_total = output_total.view(b, s, h)
-        output_bias_total = output_bias_total.view(b, s, h)
+        output_total = output_total.view(s, b, h)
+        output_bias_total = output_bias_total.view(s, b, h)
 
         return output_total, output_bias_total
 
 
 class CoreAttention(MegatronModule):
 
-    matmul_input = None
+    matmul_input_buffer = None
 
     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
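The SwitchMLP hunks only swap the roles of the first two dimensions; the routing math itself is unchanged. A standalone sketch of that token-level routing under the new [s, b, h] layout, using a plain nn.Linear router and plain callables as stand-ins for ParallelMLP, and dropping the separate bias output that the real experts return:

import torch
import torch.nn.functional as F

def switch_route_sketch(hidden_states, router, experts):
    # hidden_states: [s, b, h] after this commit (sequence, batch, hidden).
    s, b, h = hidden_states.size()
    route = F.softmax(router(hidden_states), dim=2)   # [s, b, num_experts]
    max_prob, max_ind = torch.max(route, dim=2)       # [s, b]
    max_prob = max_prob.unsqueeze(2)                  # [s, b, 1]

    # Flatten tokens so each h-vector can be routed independently.
    hidden_states = hidden_states.view(-1, h)         # [s*b, h]
    max_prob = max_prob.view(-1, 1)                   # [s*b, 1]
    max_ind = max_ind.view(-1)                        # [s*b]

    output_total = torch.empty_like(hidden_states)
    for expert_idx, expert in enumerate(experts):
        mask = (max_ind == expert_idx)
        if mask.any():
            # Send only the tokens routed to this expert through it.
            output_total[mask] = expert(hidden_states[mask])

    # Scale by the router probability and restore the [s, b, h] shape.
    return (output_total * max_prob).view(s, b, h)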
@@ -235,8 +235,8 @@ class CoreAttention(MegatronModule):
                                        output_size[0] * output_size[1], -1)
 
         # preallocting input tensor: [b * np, sq, sk]
-        if CoreAttention.matmul_input is None:
-            CoreAttention.matmul_input = torch.empty(
+        if CoreAttention.matmul_input_buffer is None:
+            CoreAttention.matmul_input_buffer = torch.empty(
                 output_size[0]*output_size[1],
                 output_size[2],
                 output_size[3],
@@ -245,7 +245,7 @@ class CoreAttention(MegatronModule):
 
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            CoreAttention.matmul_input,
+            CoreAttention.matmul_input_buffer,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
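The rename above is cosmetic; the underlying pattern is to allocate the [b * np, sq, sk] input of torch.baddbmm once and reuse it across calls, since with beta=0.0 its values are never read and it only fixes the output shape, dtype, and device. A self-contained sketch of that pattern, with shapes already transposed for bmm and a module-level buffer standing in for the class attribute (the shape check is an addition of this sketch, not part of the diff):

import torch

_matmul_input_buffer = None  # reused across calls, mirroring the class attribute

def attention_scores_sketch(query_layer, key_layer, norm_factor):
    # query_layer: [b * np, sq, hn], key_layer: [b * np, sk, hn]
    global _matmul_input_buffer
    bnp, sq, _ = query_layer.shape
    sk = key_layer.shape[1]
    if _matmul_input_buffer is None or _matmul_input_buffer.shape != (bnp, sq, sk):
        # Preallocate once; with beta=0.0 the contents are never read.
        _matmul_input_buffer = torch.empty(bnp, sq, sk,
                                           dtype=query_layer.dtype,
                                           device=query_layer.device)
    # Raw attention scores: [b * np, sq, sk]
    return torch.baddbmm(_matmul_input_buffer,
                         query_layer,
                         key_layer.transpose(1, 2),  # [b * np, hn, sk]
                         beta=0.0, alpha=1.0 / norm_factor)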
@@ -311,7 +311,7 @@ class CoreAttention(MegatronModule):
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
 
-    Self-attention layer takes input with size [b, s, h]
+    Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
@@ -529,7 +529,7 @@ def bias_dropout_add_fused_inference(x: torch.Tensor,
 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.
 
-    Transformer layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """
@@ -603,7 +603,7 @@ class ParallelTransformerLayer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
@@ -882,6 +882,8 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
+        # hidden_states: [s, b, h]
 
         # Checks.
         if inference_params:
             assert self.checkpoint_granularity is None, \
......
@@ -38,7 +38,7 @@ def _split_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     # Split along last dimension.
@@ -57,15 +57,16 @@ def _split_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     # Split along first dimension.
     dim_size = input_.size()[0]
-    assert dim_size % world_size == 0
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     local_dim_size = dim_size // world_size
     rank = get_tensor_model_parallel_rank()
-    dim_offset = rank * (local_dim_size)
+    dim_offset = rank * local_dim_size
 
     output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
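The added assert message documents the contract of _split_along_first_dim: the first (sequence) dimension must be divisible by the tensor-parallel world size, and each rank keeps its own contiguous slice. A sketch of the same slicing logic with world size and rank passed in explicitly, so it runs without a distributed setup:

import torch

def split_along_first_dim_sketch(input_, world_size, rank):
    # Bypass the split when tensor parallelism is not used.
    if world_size == 1:
        return input_
    dim_size = input_.size()[0]
    assert dim_size % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    local_dim_size = dim_size // world_size
    offset = rank * local_dim_size
    # Each rank keeps its contiguous chunk of the first (sequence) dimension.
    return input_[offset:offset + local_dim_size].contiguous()

# Example: an [8, 2, 4] tensor split across 4 ranks gives each rank [2, 2, 4].
chunk = split_along_first_dim_sketch(torch.randn(8, 2, 4), world_size=4, rank=1)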
@@ -77,7 +78,7 @@ def _gather_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     # Size and dimension.
@@ -99,7 +100,7 @@ def _gather_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     dim_size = list(input_.size())
@@ -116,11 +117,13 @@ def _reduce_scatter_along_first_dim(input_):
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if get_tensor_model_parallel_world_size() == 1:
+    if world_size == 1:
         return input_
 
     dim_size = list(input_.size())
-    assert dim_size[0] % world_size == 0
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     dim_size[0] = dim_size[0] // world_size
 
     output = torch.empty(dim_size, dtype=input_.dtype,
@@ -130,12 +133,6 @@ def _reduce_scatter_along_first_dim(input_):
     return output
 
 
-def _reduce_scatter_along_last_dim(input_):
-    output = _reduce(input_)
-    output = _split_along_last_dim(output)
-    return output
-
-
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
......
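_reduce_scatter_along_first_dim is the collective counterpart of the gather above: it sums the per-rank contributions and leaves each rank with its 1/world_size slice of the first dimension. A sketch of that operation using the public torch.distributed API directly rather than the Megatron mpu helpers, assuming an initialized process group and a recent PyTorch that provides dist.reduce_scatter_tensor (the group handle is a stand-in for the tensor-parallel group):

import torch
import torch.distributed as dist

def reduce_scatter_along_first_dim_sketch(input_, group=None):
    world_size = dist.get_world_size(group=group)
    # Bypass the collective when tensor parallelism is not used.
    if world_size == 1:
        return input_
    dim_size = list(input_.size())
    assert dim_size[0] % world_size == 0, \
        "First dimension of the tensor should be divisible by tensor parallel size"
    dim_size[0] = dim_size[0] // world_size
    output = torch.empty(dim_size, dtype=input_.dtype, device=input_.device)
    # Sum across ranks, then keep this rank's slice of the first dimension.
    dist.reduce_scatter_tensor(output, input_.contiguous(), group=group)
    return output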
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 
 from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD
......