Commit 356eb36a authored by Vijay Korthikanti

address review comments

parent aae72ffc
@@ -302,7 +302,7 @@ def parse_args(extra_args_provider=None, defaults={},
             'selective checkpointing granularity'
 
     # disable async_tensor_model_parallel_allreduce when
-    # model parallel memory optmization is enabled
+    # model parallel memory optimization is enabled
     if args.sequence_parallel:
         args.async_tensor_model_parallel_allreduce = False
@@ -489,7 +489,7 @@ def _add_training_args(parser):
     group.add_argument('--checkpoint-granularity', type=str, default=None,
                        choices=['full', 'selective'],
-                       help='Checkpoint activatins to allow for training '
+                       help='Checkpoint activations to allow for training '
                        'with larger models, sequences, and batch sizes. '
                        'It is supported at two granularities 1) full: '
                        'whole transformer layer is checkpointed, '
@@ -567,7 +567,7 @@ def _add_training_args(parser):
                        'check persist_ln_hidden_sizes if your hidden '
                        'size is supported.')
     group.add_argument('--sequence-parallel', action='store_true',
-                       help='Enable sequence parallel optmization.')
+                       help='Enable sequence parallel optimization.')
     group.add_argument('--no-gradient-accumulation-fusion',
                        action='store_false',
                        help='Disable fusing gradient accumulation to weight '
......
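The two new training options above interact with the post-parse rule in the first hunk. Below is a minimal, standalone sketch of that interaction; it mirrors the hunks rather than Megatron's full `parse_args`, and the stand-in default for `async_tensor_model_parallel_allreduce` is an assumption for illustration.

```python
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='training')
group.add_argument('--checkpoint-granularity', type=str, default=None,
                   choices=['full', 'selective'])
group.add_argument('--sequence-parallel', action='store_true')

args = parser.parse_args(['--sequence-parallel', '--checkpoint-granularity', 'selective'])
args.async_tensor_model_parallel_allreduce = True  # stand-in for the real default

# Mirrors the post-parse rule in the first hunk: sequence parallelism
# disables the overlapped tensor-model-parallel all-reduce.
if args.sequence_parallel:
    args.async_tensor_model_parallel_allreduce = False

assert args.async_tensor_model_parallel_allreduce is False
```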
@@ -220,11 +220,9 @@ class Embedding(MegatronModule):
         if self.fp32_residual_connection:
             embeddings = embeddings.float()
 
-        if self.sequence_parallel:
-            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
-
         # Dropout.
         if self.sequence_parallel:
+            embeddings = mpu.scatter_to_sequence_parallel_region(embeddings)
             with mpu.get_cuda_rng_tracker().fork():
                 embeddings = self.embedding_dropout(embeddings)
         else:
......
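The Embedding change above moves the `scatter_to_sequence_parallel_region` call under the dropout branch, so the scattered activations are dropped out inside the sequence-parallel RNG fork. A shape-level sketch of what the scatter does, simulated on a single process with assumed toy sizes (the real `mpu` op is a collective; conceptually its backward gathers the slices back along the sequence dimension):

```python
import torch

# Assumed toy sizes and tensor-parallel layout.
s, b, h = 8, 2, 16
tp_world_size, tp_rank = 4, 1

embeddings = torch.randn(s, b, h)                       # [s, b, h]
assert s % tp_world_size == 0
local = torch.chunk(embeddings, tp_world_size, dim=0)[tp_rank]
print(local.shape)                                      # torch.Size([2, 2, 16]) -> [s/tp, b, h]
```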
@@ -130,21 +130,21 @@ class SwitchMLP(MegatronModule):
             self.experts.append(ParallelMLP(init_method, output_layer_init_method))
 
     def forward(self, hidden_states):
-        # hidden_states: [b, s, h]
-        b = hidden_states.size(0)
-        s = hidden_states.size(1)
+        # hidden_states: [s, b, h]
+        s = hidden_states.size(0)
+        b = hidden_states.size(1)
         h = hidden_states.size(2)
         route = self.router(hidden_states)
         route = torch.nn.functional.softmax(route, dim=2)
         max_prob, max_ind = torch.max(route, dim=2)
-        max_prob = torch.unsqueeze(max_prob, 2) # [b s 1]
+        max_prob = torch.unsqueeze(max_prob, 2) # [s b 1]
 
         # TODO (rprenger) TODO this could be made easier to read
-        # Converting [b, s, h] to [b*s, h].
+        # Converting [s, b, h] to [s*b, h].
         # Each vector could be routed differently
-        hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [b*s h]
-        max_prob = max_prob.view(-1, max_prob.size(2)) # [b*s 1]
-        max_ind = max_ind.view(-1) # [b*s]
+        hidden_states = hidden_states.view(-1, hidden_states.size(2)) # [s*b h]
+        max_prob = max_prob.view(-1, max_prob.size(2)) # [s*b 1]
+        max_ind = max_ind.view(-1) # [s*b]
 
         output_total = torch.empty_like(hidden_states)
         output_bias_total = torch.empty_like(hidden_states)
@@ -160,14 +160,14 @@ class SwitchMLP(MegatronModule):
         output_total = output_total*max_prob
         output_bias_total = output_bias_total*max_prob
-        output_total = output_total.view(b, s, h)
-        output_bias_total = output_bias_total.view(b, s, h)
+        output_total = output_total.view(s, b, h)
+        output_bias_total = output_bias_total.view(s, b, h)
 
         return output_total, output_bias_total
 
 
 class CoreAttention(MegatronModule):
 
-    matmul_input = None
+    matmul_input_buffer = None
 
     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
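Since every `[b, s, h]` annotation in `SwitchMLP.forward` now becomes `[s, b, h]`, here is a compact sketch of that routing path with plain `nn.Linear` modules standing in for `ParallelMLP`; the sizes, expert count, and stand-in experts are assumptions for illustration.

```python
import torch

s, b, h, num_experts = 4, 2, 8, 3
router = torch.nn.Linear(h, num_experts)
experts = torch.nn.ModuleList([torch.nn.Linear(h, h) for _ in range(num_experts)])

hidden_states = torch.randn(s, b, h)                       # [s, b, h]
route = torch.nn.functional.softmax(router(hidden_states), dim=2)
max_prob, max_ind = torch.max(route, dim=2)                # [s, b]
max_prob = torch.unsqueeze(max_prob, 2)                    # [s, b, 1]

flat = hidden_states.view(-1, h)                           # [s*b, h]
max_prob = max_prob.view(-1, 1)                            # [s*b, 1]
max_ind = max_ind.view(-1)                                 # [s*b]

output_total = torch.empty_like(flat)
for expert_num, expert in enumerate(experts):
    rows = (max_ind == expert_num).nonzero(as_tuple=True)[0]
    output_total[rows] = expert(flat[rows])                # route each token to its expert

output_total = (output_total * max_prob).view(s, b, h)     # back to [s, b, h]
```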
@@ -235,8 +235,8 @@ class CoreAttention(MegatronModule):
             output_size[0] * output_size[1], -1)
 
         # preallocting input tensor: [b * np, sq, sk]
-        if CoreAttention.matmul_input is None:
-            CoreAttention.matmul_input = torch.empty(
+        if CoreAttention.matmul_input_buffer is None:
+            CoreAttention.matmul_input_buffer = torch.empty(
                 output_size[0]*output_size[1],
                 output_size[2],
                 output_size[3],
@@ -245,7 +245,7 @@ class CoreAttention(MegatronModule):
 
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            CoreAttention.matmul_input,
+            CoreAttention.matmul_input_buffer,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
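On the renamed `matmul_input_buffer`: because the call passes `beta=0.0`, `torch.baddbmm` ignores the values in that first argument (PyTorch documents that even NaN/Inf in it are not propagated) and returns `alpha * Q·Kᵀ`. A small sketch of the call pattern, with toy shapes that are assumptions:

```python
import torch

bnp, sq, sk, hn = 6, 4, 4, 16             # b*np, query length, key length, head size
norm_factor = hn ** 0.5

matmul_input_buffer = torch.empty(bnp, sq, sk)    # uninitialized; values are never read
query_layer = torch.randn(sq, bnp, hn)            # [sq, b*np, hn]
key_layer = torch.randn(sk, bnp, hn)              # [sk, b*np, hn]

matmul_result = torch.baddbmm(
    matmul_input_buffer,
    query_layer.transpose(0, 1),                  # [b*np, sq, hn]
    key_layer.transpose(0, 1).transpose(1, 2),    # [b*np, hn, sk]
    beta=0.0, alpha=(1.0 / norm_factor))
assert matmul_result.shape == (bnp, sq, sk)
```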
@@ -311,7 +311,7 @@ class CoreAttention(MegatronModule):
 class ParallelAttention(MegatronModule):
     """Parallel self-attention layer abstract class.
 
-    Self-attention layer takes input with size [b, s, h]
+    Self-attention layer takes input with size [s, b, h]
     and returns output of the same size.
     """
@@ -529,7 +529,7 @@ def bias_dropout_add_fused_inference(x: torch.Tensor,
 class ParallelTransformerLayer(MegatronModule):
     """A single transformer layer.
 
-    Transformer layer takes input with size [b, s, h] and returns an
+    Transformer layer takes input with size [s, b, h] and returns an
     output of the same size.
     """
@@ -603,7 +603,7 @@ class ParallelTransformerLayer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
-        # hidden_states: [b, s, h]
+        # hidden_states: [s, b, h]
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
@@ -882,6 +882,8 @@ class ParallelTransformer(MegatronModule):
     def forward(self, hidden_states, attention_mask,
                 encoder_output=None, enc_dec_attn_mask=None,
                 inference_params=None):
+        # hidden_states: [s, b, h]
+
         # Checks.
         if inference_params:
             assert self.checkpoint_granularity is None, \
......
@@ -38,7 +38,7 @@ def _split_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     # Split along last dimension.
@@ -57,15 +57,16 @@ def _split_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     # Split along first dimension.
     dim_size = input_.size()[0]
-    assert dim_size % world_size == 0
+    assert dim_size % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     local_dim_size = dim_size // world_size
     rank = get_tensor_model_parallel_rank()
-    dim_offset = rank * (local_dim_size)
+    dim_offset = rank * local_dim_size
 
     output = input_[dim_offset:dim_offset+local_dim_size].contiguous()
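The new assertion message documents the precondition of this slice. A single-process sketch of the arithmetic (`world_size` and `rank` are assumed values here; the real function reads them from the tensor-model-parallel group):

```python
import torch

input_ = torch.arange(8 * 3).reshape(8, 3)   # first dim plays the role of the sequence dim
world_size, rank = 4, 2

dim_size = input_.size()[0]
assert dim_size % world_size == 0, \
    "First dimension of the tensor should be divisible by tensor parallel size"
local_dim_size = dim_size // world_size
dim_offset = rank * local_dim_size

output = input_[dim_offset:dim_offset + local_dim_size].contiguous()
print(output.shape)   # torch.Size([2, 3]); rows 4-5 of the original tensor
```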
@@ -77,7 +78,7 @@ def _gather_along_last_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     # Size and dimension.
@@ -99,7 +100,7 @@ def _gather_along_first_dim(input_):
     world_size = get_tensor_model_parallel_world_size()
 
     # Bypass the function if we are using only 1 GPU.
-    if world_size==1:
+    if world_size == 1:
         return input_
 
     dim_size = list(input_.size())
@@ -116,11 +117,13 @@ def _reduce_scatter_along_first_dim(input_):
     """Reduce-scatter the input tensor across model parallel group."""
     world_size = get_tensor_model_parallel_world_size()
     # Bypass the function if we are using only 1 GPU.
-    if get_tensor_model_parallel_world_size() == 1:
+    if world_size == 1:
         return input_
 
     dim_size = list(input_.size())
-    assert dim_size[0] % world_size == 0
+    assert dim_size[0] % world_size == 0, \
+        "First dimension of the tensor should be divisible by tensor parallel size"
     dim_size[0] = dim_size[0] // world_size
 
     output = torch.empty(dim_size, dtype=input_.dtype,
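For `_reduce_scatter_along_first_dim`, the same divisibility requirement applies to the slice each rank keeps: reduce-scatter is equivalent to summing the per-rank tensors and handing rank `r` the `r`-th chunk of the first dimension. A sketch of that semantics simulated without `torch.distributed`, with rank count and shapes as toy assumptions:

```python
import torch

world_size = 4
per_rank_inputs = [torch.full((8, 2), float(r)) for r in range(world_size)]  # one tensor per rank

summed = torch.stack(per_rank_inputs).sum(dim=0)      # the "reduce" part
chunks = torch.chunk(summed, world_size, dim=0)       # the "scatter" part
rank = 1
output = chunks[rank].contiguous()
print(output.shape)   # torch.Size([2, 2]) == [8 / world_size, 2]
```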
@@ -130,12 +133,6 @@ def _reduce_scatter_along_first_dim(input_):
     return output
 
 
-def _reduce_scatter_along_last_dim(input_):
-    output = _reduce(input_)
-    output = _split_along_last_dim(output)
-    return output
-
-
 class _CopyToModelParallelRegion(torch.autograd.Function):
     """Pass the input to the model parallel region."""
......
@@ -13,7 +13,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import torch
 from apex.optimizers import FusedAdam as Adam
 from apex.optimizers import FusedSGD as SGD
......