Commit 53931b8b authored by Vijay Korthikanti

address review comments

parent 8acbbe25
@@ -246,14 +246,14 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.fp16 or args.bf16, \
             'residual connection in fp32 only supported when using fp16 or bf16.'
-    if args.wd_incr_style == 'constant':
-        assert args.start_wd is None
-        assert args.end_wd is None
-        args.start_wd = args.weight_decay
-        args.end_wd = args.weight_decay
+    if args.weight_decay_incr_style == 'constant':
+        assert args.start_weight_decay is None
+        assert args.end_weight_decay is None
+        args.start_weight_decay = args.weight_decay
+        args.end_weight_decay = args.weight_decay
     else:
-        assert args.start_wd is not None
-        assert args.end_wd is not None
+        assert args.start_weight_decay is not None
+        assert args.end_weight_decay is not None
     TORCH_MAJOR = int(torch.__version__.split('.')[0])
     TORCH_MINOR = int(torch.__version__.split('.')[1])
@@ -404,11 +404,11 @@ def _add_regularization_args(parser):
                        help='Dropout probability for hidden state transformer.')
     group.add_argument('--weight-decay', type=float, default=0.01,
                        help='Weight decay coefficient for L2 regularization.')
-    group.add_argument('--start-wd', type=float,
+    group.add_argument('--start-weight-decay', type=float,
                        help='Initial weight decay coefficient for L2 regularization.')
-    group.add_argument('--end-wd', type=float,
+    group.add_argument('--end-weight-decay', type=float,
                        help='End of run weight decay coefficient for L2 regularization.')
-    group.add_argument('--wd-incr-style', type=str, default='constant',
+    group.add_argument('--weight-decay-incr-style', type=str, default='constant',
                        choices=['constant', 'linear', 'cosine'],
                        help='Weight decay increment function.')
     group.add_argument('--clip-grad', type=float, default=1.0,
...
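As a quick sanity check of the renamed flags, here is a minimal standalone argparse sketch. It mirrors the add_argument calls in the hunk above but is not the Megatron-LM parser itself, and the example values are made up.

import argparse

# Minimal sketch, not the real Megatron-LM parser: only the renamed
# regularization flags from the hunk above are reproduced here.
parser = argparse.ArgumentParser()
parser.add_argument('--weight-decay', type=float, default=0.01)
parser.add_argument('--start-weight-decay', type=float)
parser.add_argument('--end-weight-decay', type=float)
parser.add_argument('--weight-decay-incr-style', type=str, default='constant',
                    choices=['constant', 'linear', 'cosine'])

# Illustrative values only; a non-constant style requires both endpoints,
# matching the asserts in parse_args above.
args = parser.parse_args(['--weight-decay-incr-style', 'cosine',
                          '--start-weight-decay', '0.01',
                          '--end-weight-decay', '0.1'])
print(args.start_weight_decay, args.end_weight_decay)  # 0.01 0.1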
@@ -49,20 +49,20 @@ class DropPath(MegatronModule):
     (when applied in main path of residual blocks).
     """
-    def __init__(self, drop_prob=None):
+    def __init__(self, drop_prob=0.):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob
-    def forward(self, x):
+    def forward(self, hidden_state):
         if self.drop_prob == 0. or not self.training:
-            return x
+            return hidden_state
         keep_prob = 1 - self.drop_prob
         # work with diff dim tensors, not just 2D ConvNets
-        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+        shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
         random_tensor = keep_prob + \
-            torch.rand(shape, dtype=x.dtype, device=x.device)
+            torch.rand(shape, dtype=hidden_state.dtype, device=hidden_state.device)
         random_tensor.floor_()  # binarize
-        output = x.div(keep_prob) * random_tensor
+        output = hidden_state.div(keep_prob) * random_tensor
         return output
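To see what the renamed forward computes, here is a minimal standalone sketch of the same masking logic. The shape, drop_prob, and seed are illustrative, and the example assumes the dimension being dropped is the leading one, as in the code above.

import torch

torch.manual_seed(0)
drop_prob = 0.25
keep_prob = 1 - drop_prob
hidden_state = torch.ones(8, 4, 16)  # illustrative shape only

# One Bernoulli(keep_prob) draw per entry of the leading dimension,
# broadcast over the remaining dimensions.
shape = (hidden_state.shape[0],) + (1,) * (hidden_state.ndim - 1)
random_tensor = keep_prob + torch.rand(shape)
random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0

# Survivors are rescaled by 1 / keep_prob so the expected value is unchanged.
output = hidden_state.div(keep_prob) * random_tensor
print(output[:, 0, 0])  # each entry is either 0.0 or 1 / keep_prob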
@@ -437,7 +437,6 @@ class ParallelTransformerLayer(MegatronModule):
         super(ParallelTransformerLayer, self).__init__()
         self.layer_number = layer_number
         self.layer_type = layer_type
-        self.drop_path_rate = drop_path_rate
         self.apply_residual_connection_post_layernorm \
             = args.apply_residual_connection_post_layernorm
@@ -460,7 +459,7 @@ class ParallelTransformerLayer(MegatronModule):
             attn_mask_type=self_attn_mask_type)
         self.hidden_dropout = args.hidden_dropout
         self.bias_dropout_fusion = args.bias_dropout_fusion
-        self.drop_path = DropPath(drop_path_rate)
+        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else None
         # Layernorm on the attention output
         self.post_attention_layernorm = LayerNorm(
@@ -504,7 +503,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = hidden_states
-        if self.drop_path_rate == 0.0:
+        if self.drop_path is None:
             # jit scripting for a nn.module (with dropout) is not
             # trigerring the fusion kernel. For now, we use two
             # different nn.functional routines to account for varying
@@ -564,7 +563,7 @@ class ParallelTransformerLayer(MegatronModule):
         else:
             residual = layernorm_input
-        if self.drop_path_rate == 0.0:
+        if self.drop_path is None:
             # re-enable torch grad to enable fused optimization.
             with torch.enable_grad():
                 output = bias_dropout_add_func(
@@ -608,7 +607,7 @@ class ParallelTransformer(MegatronModule):
         self.num_layers = mpu.get_num_layers(
             args, args.model_type == ModelType.encoder_and_decoder)
-        self.dpr = [x.item() for x in torch.linspace(0, self.drop_path_rate, args.num_layers)]
+        self.drop_path_rates = [rate.item() for rate in torch.linspace(0, self.drop_path_rate, args.num_layers)]
         # Transformer layers.
         def build_layer(layer_number):
@@ -618,7 +617,7 @@ class ParallelTransformer(MegatronModule):
                 layer_number,
                 layer_type=layer_type,
                 self_attn_mask_type=self_attn_mask_type,
-                drop_path_rate=self.dpr[layer_number - 1])
+                drop_path_rate=self.drop_path_rates[layer_number - 1])
         if args.virtual_pipeline_model_parallel_size is not None:
             assert args.num_layers % args.virtual_pipeline_model_parallel_size == 0, \
                 'num_layers_per_stage must be divisible by ' \
...
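The renamed drop_path_rates list is a simple linear ramp: the first layer gets 0 and the last layer gets the full self.drop_path_rate. A standalone example with illustrative values (results shown up to float32 rounding):

import torch

drop_path_rate = 0.1  # illustrative stand-in for self.drop_path_rate
num_layers = 5        # illustrative layer count
drop_path_rates = [rate.item() for rate in
                   torch.linspace(0, drop_path_rate, num_layers)]
print(drop_path_rates)  # approximately [0.0, 0.025, 0.05, 0.075, 0.1]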
@@ -341,9 +341,9 @@ def get_learning_rate_scheduler(optimizer):
         warmup_steps=warmup_steps,
         decay_steps=decay_steps,
         decay_style=args.lr_decay_style,
-        start_wd=args.start_wd,
-        end_wd=args.end_wd,
-        wd_incr_style=args.wd_incr_style,
+        start_wd=args.start_weight_decay,
+        end_wd=args.end_weight_decay,
+        wd_incr_style=args.weight_decay_incr_style,
         use_checkpoint_lr_scheduler=args.use_checkpoint_lr_scheduler,
         override_lr_scheduler=args.override_lr_scheduler)
...
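For context on the start_wd / end_wd / wd_incr_style values handed to the scheduler above, below is a hypothetical helper sketching what the three increment styles could compute per step. The function name, signature, and formulas are assumptions for illustration and are not the scheduler's actual code.

import math

def weight_decay_at(step, total_steps, start_weight_decay,
                    end_weight_decay, incr_style):
    # Hypothetical helper, not the Megatron-LM scheduler.
    if incr_style == 'constant':
        # parse_args forces start == end == --weight-decay in this case.
        return start_weight_decay
    frac = min(step / total_steps, 1.0)
    delta = end_weight_decay - start_weight_decay
    if incr_style == 'linear':
        return start_weight_decay + frac * delta
    if incr_style == 'cosine':
        # half-cosine ramp from the start value to the end value
        return start_weight_decay + 0.5 * (1.0 - math.cos(math.pi * frac)) * delta
    raise ValueError(f'unknown weight decay increment style: {incr_style}')

print(weight_decay_at(500, 1000, 0.01, 0.1, 'cosine'))  # 0.055 at the midpoint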