"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a2874af2971d1b262371d9a6fae653662c4a5e95"
Commit 14c85e64 authored by Rewon Child

Merge branch 'main' into rc-debug-underflow

parents 4e77e7c6 c1faa9fe
@@ -202,7 +202,23 @@ def parse_args(extra_args_provider=None, defaults={},
         assert args.checkpoint_activations, \
             'for distribute-checkpointed-activations to work you '\
             'need to enable checkpoint-activations'
+
+    # custom kernel constraints check
+    seq_len = args.seq_length
+    attn_batch_size = \
+        (args.num_attention_heads / args.tensor_model_parallel_size) * \
+        args.micro_batch_size
+    # constraints on sequence length and attn_batch_size to enable warp based
+    # optimization and upper triangular optimization (for causal mask)
+    custom_kernel_constraint = seq_len > 16 and seq_len <= 2048 and \
+        seq_len % 4 == 0 and attn_batch_size % 4 == 0
+    if not (args.fp16 and custom_kernel_constraint and
+            args.masked_softmax_fusion):
+        print('WARNING: constraints for invoking optimized'
+              ' fused softmax kernel are not met. We default back to unfused'
+              ' kernel invocations.')
+
     # Load scaled_masked_softmax_fusion kernels
     if args.masked_softmax_fusion:
         fused_kernels.load_scaled_upper_triang_masked_softmax_fusion_kernel()
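For context: this check gates the optimized fused-softmax path on the sequence length and on the per-tensor-parallel-rank attention batch size (heads per partition times micro-batch size). A minimal standalone sketch of the same predicate (the function name and example values are illustrative, not from the repo):

def fused_softmax_constraint(seq_length, num_attention_heads,
                             tensor_model_parallel_size, micro_batch_size):
    # Warp-based and upper-triangular (causal) optimizations require
    # 16 < seq_length <= 2048, with seq_length and the per-partition
    # attention batch size both divisible by 4.
    attn_batch_size = (num_attention_heads // tensor_model_parallel_size) \
        * micro_batch_size
    return (16 < seq_length <= 2048 and
            seq_length % 4 == 0 and
            attn_batch_size % 4 == 0)

# Example: seq 1024, 16 heads, 2-way tensor parallelism, micro-batch 4
# -> attn_batch_size = 32, so all constraints hold.
assert fused_softmax_constraint(1024, 16, 2, 4)
assert not fused_softmax_constraint(4096, 16, 2, 4)  # too long for the kernel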
@@ -480,9 +496,9 @@ def _add_checkpointing_args(parser):
                        help='Output directory to save checkpoints to.')
     group.add_argument('--save-interval', type=int, default=None,
                        help='Number of iterations between checkpoint saves.')
-    group.add_argument('--no-save-optim', action='store_true',
+    group.add_argument('--no-save-optim', action='store_true', default=None,
                        help='Do not save current optimizer.')
-    group.add_argument('--no-save-rng', action='store_true',
+    group.add_argument('--no-save-rng', action='store_true', default=None,
                        help='Do not save current rng state.')
     group.add_argument('--load', type=str, default=None,
                        help='Directory containing a model checkpoint.')
...
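The switch to default=None on --no-save-optim and --no-save-rng is what lets the args_defaults overrides added further down in this merge take effect: a programmatic default can only be applied when the parsed attribute is still None, whereas a plain store_true flag already defaults to False. A self-contained sketch of that pattern (apply_defaults is an illustrative stand-in, not the repo's function name):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--no-save-optim', action='store_true', default=None)
args = parser.parse_args([])  # flag not given -> None, not False

def apply_defaults(args, defaults):
    # Illustrative: fill in a default only if the user left the flag unset.
    for key, value in defaults.items():
        if getattr(args, key) is None:
            setattr(args, key, value)
    return args

apply_defaults(args, {'no_save_optim': True})
assert args.no_save_optim is True  # would have stayed False without default=None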
@@ -343,12 +343,15 @@ def load_checkpoint(model, optimizer, lr_scheduler, load_arg='load', strict=True
             np.random.set_state(state_dict['np_rng_state'])
             torch.set_rng_state(state_dict['torch_rng_state'])
             torch.cuda.set_rng_state(state_dict['cuda_rng_state'])
+            # Check for empty states array
+            if not state_dict['rng_tracker_states']:
+                raise KeyError
             mpu.get_cuda_rng_tracker().set_states(
                 state_dict['rng_tracker_states'])
         except KeyError:
-            print_rank_0('Unable to load optimizer from checkpoint {}. '
+            print_rank_0('Unable to load rng state from checkpoint {}. '
                          'Specify --no-load-rng or --finetune to prevent '
-                         'attempting to load the optimizer state, '
+                         'attempting to load the rng state, '
                          'exiting ...'.format(checkpoint_name))
             sys.exit()
...
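The added guard turns an empty rng_tracker_states entry into the same KeyError path as a missing key, so stale or partial checkpoints hit the corrected error message instead of silently installing no tracker states. A hypothetical minimal reproduction of the control flow:

def load_rng_tracker_states(state_dict):
    # Illustrative only: an empty (falsy) entry is escalated to KeyError
    # so it shares the missing-key error path.
    try:
        states = state_dict['rng_tracker_states']
        if not states:
            raise KeyError
        return states
    except KeyError:
        raise SystemExit('Unable to load rng state from checkpoint. '
                         'Specify --no-load-rng or --finetune.')

assert load_rng_tracker_states(
    {'rng_tracker_states': {'model-parallel-rng': 'state'}})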
@@ -113,18 +113,23 @@ class FusedScaleMaskSoftmax(torch.nn.Module):
         assert (
             self.scale is None or softmax_in_fp32
         ), "softmax should be in fp32 when scaled"

     def forward(self, input, mask):
         # [b, np, sq, sk]
+        assert input.dim() == 4
         data_size = input.size()
         query_seq_len = data_size[-2]
         key_seq_len = data_size[-1]
-        assert input.dim() == 4
+        attn_batch_size = data_size[0] * data_size[1]

+        # constraints on various tensor dimensions to enable warp based
+        # optimization and upper triangular optimization (for causal mask)
+        custom_kernel_constraint = key_seq_len > 16 and key_seq_len <= 2048 and \
+            query_seq_len % 4 == 0 and attn_batch_size % 4 == 0
+
         # invoke custom kernel
-        if self.input_in_fp16 and key_seq_len <= 2048 and mask is not None and \
-                query_seq_len % 4 == 0 and self.scaled_masked_softmax_fusion:
+        if self.input_in_fp16 and mask is not None and \
+                custom_kernel_constraint and self.scaled_masked_softmax_fusion:
             scale = self.scale if self.scale is not None else 1.0

             if self.attn_mask_type == AttnMaskType.causal:
...
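In the forward pass the constraint is re-derived from the input's actual [b, np, sq, sk] shape, so the kernel choice tracks the tensor seen at run time rather than the configured arguments. An illustrative shape-based version of the check (standalone; the function name is hypothetical):

import torch

def can_use_fused_kernel(scores, input_in_fp16=True, fusion_enabled=True):
    # scores: attention logits shaped [b, np, sq, sk]
    assert scores.dim() == 4
    b, np_, sq, sk = scores.size()
    attn_batch_size = b * np_
    constraint = 16 < sk <= 2048 and sq % 4 == 0 and attn_batch_size % 4 == 0
    return input_in_fp16 and fusion_enabled and constraint

scores = torch.randn(4, 16, 1024, 1024, dtype=torch.float16)
print(can_use_fused_kernel(scores))  # True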
@@ -351,6 +351,8 @@ def communicate(tensor_send_next, tensor_send_prev, recv_forward, recv_backward)
         reqs = torch.distributed.batch_isend_irecv(ops)
         for req in reqs:
             req.wait()
+    # Temporary workaround for batch_isend_irecv() race condition.
+    torch.cuda.synchronize()
     return tensor_recv_prev, tensor_recv_next
...
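The added torch.cuda.synchronize() reflects that waiting on the requests returned by batch_isend_irecv() was, at the time, not sufficient to guarantee the transfers had completed before subsequent kernels ran. A toy two-rank exchange illustrating the pattern (assumes an initialized NCCL process group with one CUDA device per rank; names are illustrative):

import torch
import torch.distributed as dist

def exchange_with_peer(tensor_send, tensor_recv, peer):
    # Batch the send and receive into one launch, as communicate() does.
    ops = [dist.P2POp(dist.isend, tensor_send, peer),
           dist.P2POp(dist.irecv, tensor_recv, peer)]
    reqs = dist.batch_isend_irecv(ops)
    for req in reqs:
        req.wait()
    # Temporary workaround for the batch_isend_irecv() race condition:
    # wait() alone may return before the transfers are visible.
    torch.cuda.synchronize()
    return tensor_recv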
@@ -92,7 +92,9 @@ def main():
     """Main program."""

     initialize_megatron(extra_args_provider=add_text_generate_args,
-                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
+                        args_defaults={'tokenizer_type': 'GPT2BPETokenizer',
+                                       'no_load_rng': True,
+                                       'no_load_optim': True})

     # Set up model and load checkpoint.
     model = get_model(model_provider)
...
@@ -200,6 +200,8 @@ def main():
                                     'micro_batch_size': 1,
                                     'no_load_optim': True,
                                     'no_load_rng': True,
+                                    'no_save_optim': True,
+                                    'no_save_rng': True,
                                     'save_interval': 1})

     args = get_args()
...