Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed
......@@ -59,7 +59,7 @@ def validate_yaml(args, defaults={}):
(args.world_size // args.model_parallel.tensor_model_parallel_size))
args.model_parallel.transformer_pipeline_model_parallel_size = (
args.model_parallel.pipeline_model_parallel_size - 1
if args.standalone_embedding_stage else
if args.account_for_embedding_in_pipeline_split else
args.model_parallel.pipeline_model_parallel_size
)
# Checks.
......
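Not part of the diff: a minimal sketch of the arithmetic behind the renamed flag above. When the embedding layer is counted as its own pipeline split, the transformer layers are spread over one stage fewer. Names here are illustrative, not Megatron's API.

def transformer_pipeline_size(pipeline_model_parallel_size, account_for_embedding_in_pipeline_split):
    # With the embedding on its own pipeline stage, one stage is taken away
    # from the stages available to transformer layers.
    if account_for_embedding_in_pipeline_split:
        return pipeline_model_parallel_size - 1
    return pipeline_model_parallel_size

assert transformer_pipeline_size(4, True) == 3
assert transformer_pipeline_size(4, False) == 4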
......@@ -154,7 +154,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
get_blend_from_list(args.valid_data_path),
get_blend_from_list(args.test_data_path)
],
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
path_to_cache=args.data_cache_path,
tokenizer=tokenizer,
......
......@@ -35,8 +35,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
get_gpt_layer_local_spec,
get_gpt_layer_with_transformer_engine_spec,
)
import torch._dynamo
torch._dynamo.config.suppress_errors = True
stimer = StragglerDetector()
......@@ -64,6 +63,15 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
# record stack information for the trace events
trace_alloc_record_context=True)
def oom_observer(device, alloc, device_alloc, device_free):
# snapshot right after an OOM happened
print('saving allocated state during OOM')
snapshot = torch.cuda.memory._snapshot()
from pickle import dump
dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
torch._C._cuda_attach_out_of_memory_observer(oom_observer)
print_rank_0('building GPT model ...')
# Experimental loading arguments from yaml
if args.yaml_cfg is not None:
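Usage note, not part of the commit: the snapshot pickled by oom_observer above can be inspected offline, or dragged into https://pytorch.org/memory_viz for an interactive view. The sketch below assumes the dict layout currently returned by torch.cuda.memory._snapshot() (a "segments" list with "total_size" fields), which is a private API and may change; the file name is illustrative.

import pickle

# Load a snapshot written by oom_observer, e.g. "oom_rank-0_snapshot.pickle".
with open("oom_rank-0_snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# Each entry in snapshot["segments"] describes one caching-allocator segment.
total = sum(seg["total_size"] for seg in snapshot["segments"])
print(f"{len(snapshot['segments'])} segments, {total / 2**20:.1f} MiB reserved")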
......@@ -91,11 +99,11 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
if use_te:
transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention, args.fp8)
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
else:
transformer_layer_spec = get_gpt_layer_local_spec(
args.num_experts, args.moe_grouped_gemm,
args.qk_layernorm, args.multi_latent_attention)
args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
build_model_context = nullcontext
build_model_context_args = {}
......@@ -128,7 +136,9 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megat
rotary_base=args.rotary_base,
rope_scaling=args.use_rope_scaling
)
# model = torch.compile(model, mode="max-autotune-no-cudagraphs")
print_rank_0(model)
return model
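The commented-out torch.compile call above hints at an optional compilation path. A hedged sketch of how it could be gated behind a flag; the enable argument is hypothetical, not a Megatron option.

import torch

def maybe_compile(model, enable):
    # Optionally wrap the model with torch.compile. "max-autotune-no-cudagraphs"
    # trades a longer warm-up for faster steady-state kernels without CUDA graphs.
    if enable and hasattr(torch, "compile"):
        return torch.compile(model, mode="max-autotune-no-cudagraphs")
    return model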
......@@ -148,8 +158,8 @@ def get_batch(data_iterator):
return batch.values()
# define spiky loss as a variation of 20% or more
SPIKY_LOSS_PERC = 0.2
# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
......@@ -185,11 +195,22 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
# Check for spiky loss
if args.check_for_spiky_loss:
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large,
threshold=SPIKY_LOSS_FACTOR,
context="loss",
),
message="Spiky loss",
tolerance=0.0, # forward pass calculations are deterministic
fatal=False,
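Not part of the diff: a self-contained sketch of the "unexpectedly large" rejection idea used above, flagging a loss once it exceeds SPIKY_LOSS_FACTOR times the largest value seen so far. The class is illustrative; Megatron's RerunStateMachine keeps this state internally.

import torch

SPIKY_LOSS_FACTOR = 10

class LargestSeenTracker:
    # Illustrative stand-in for the rerun state machine's per-metric history.
    def __init__(self):
        self.max_seen = None

    def is_unexpectedly_large(self, value, threshold, context="loss"):
        v = float(value)
        if self.max_seen is None or v <= threshold * self.max_seen:
            # Not spiky: remember the new maximum and accept the value.
            self.max_seen = max(self.max_seen or v, v)
            return False
        # Spiky: more than `threshold` times the largest value observed so far.
        return True

tracker = LargestSeenTracker()
assert not tracker.is_unexpectedly_large(torch.tensor(2.0), SPIKY_LOSS_FACTOR)
assert tracker.is_unexpectedly_large(torch.tensor(25.0), SPIKY_LOSS_FACTOR)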
......@@ -250,7 +271,6 @@ def core_gpt_dataset_config_from_args(args):
sequence_length=args.seq_length,
blend=blend,
blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
......
File mode changed from 100755 to 100644
......@@ -104,8 +104,8 @@ def get_batch(data_iterator):
return batch.values()
# define spiky loss as a variation of 20% or more
SPIKY_LOSS_PERC = 0.2
# define spiky loss as a loss that's 10x the max loss observed
SPIKY_LOSS_FACTOR = 10
def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
......@@ -141,11 +141,22 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=torch.isinf,
message="found Inf in local forward loss calculation",
tolerance=0.0, # forward pass calculations are deterministic
fatal=True,
)
# Check for spiky loss
if args.check_for_spiky_loss:
rerun_state_machine.validate_result(
result=loss[0],
rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
rejection_func=partial(
rerun_state_machine.is_unexpectedly_large,
threshold=SPIKY_LOSS_FACTOR,
context="loss",
),
message="Spiky loss",
tolerance=0.0, # forward pass calculations are deterministic
fatal=False,
......@@ -207,7 +218,6 @@ def core_gpt_dataset_config_from_args(args):
sequence_length=args.seq_length,
blend=blend,
blend_per_split=blend_per_split,
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
num_dataset_builder_threads=args.num_dataset_builder_threads,
path_to_cache=args.data_cache_path,
......
......@@ -189,7 +189,6 @@ def train_valid_test_datasets_provider(train_valid_test_num_samples):
get_blend_from_list(args.valid_data_path),
get_blend_from_list(args.test_data_path)
],
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
split_preprocessing=retro_config.retro_split_preprocessing,
path_to_cache=args.data_cache_path,
......
......@@ -141,6 +141,8 @@ def model_provider(
share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
position_embedding_type=args.position_embedding_type,
rotary_percent=args.rotary_percent,
relative_attention_num_buckets=args.relative_attention_num_buckets,
relative_attention_max_distance=args.relative_attention_max_distance,
add_encoder=add_encoder,
add_decoder=add_decoder,
)
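The two new arguments above configure T5-style relative position bias. For reference, a hedged sketch of the standard bucketing scheme they parameterize, adapted from the T5 formulation rather than Megatron's exact code: nearby offsets each get their own bucket, while distant offsets share log-spaced buckets capped at max_distance.

import math
import torch

def relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
    # Bidirectional: half the buckets for keys after the query, half for before.
    num_buckets //= 2
    ret = (relative_position > 0).long() * num_buckets
    n = torch.abs(relative_position)
    # Small offsets get exact buckets; larger ones are log-spaced up to max_distance.
    max_exact = num_buckets // 2
    is_small = n < max_exact
    large = max_exact + (
        torch.log(n.clamp(min=1).float() / max_exact)
        / math.log(max_distance / max_exact)
        * (num_buckets - max_exact)
    ).long()
    large = torch.minimum(large, torch.full_like(large, num_buckets - 1))
    return ret + torch.where(is_small, n, large)

buckets = relative_position_bucket(torch.arange(-8, 9))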
......@@ -226,7 +228,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples: int):
get_blend_from_list(args.valid_data_path),
get_blend_from_list(args.test_data_path),
],
renormalize_blend_weights=args.renormalize_blend_weights,
split=args.split,
path_to_cache=args.data_cache_path,
tokenizer=tokenizer,
......
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
......@@ -22,42 +22,13 @@ from megatron.core.models.vision.vit_layer_specs import (
get_vit_layer_with_local_spec,
)
from megatron.core.transformer.spec_utils import import_module
from megatron.core.packed_seq_params import PackedSeqParams
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import get_batch_on_this_cp_rank
from megatron.core import mpu
from megatron.core.models.multimodal import context_parallel
from pretrain_gpt import loss_func
def calculate_model_parallel_padding(decoder_seq_len, text_only=False):
args = get_args()
cp_size = args.context_parallel_size
tp_size = args.tensor_model_parallel_size
mp_padding_needed = 0
# TP Comm overlap is performed with combined text+image embeddings.
# text_only flag skips using the full sequence length to calculate padding and uses
# the provided decoder_seq_len
if args.sequence_parallel and args.decoder_tp_comm_overlap and not text_only:
# If TP Comm Overlap is enabled for combined text+image embedding in LM backbone,
# user needs to provide decoder_seq_length with any potential padding needed for SP+CP
assert args.decoder_seq_length is not None, \
"Please provide --decoder-seq-length when using TP Comm overlap for LM backbone"
mp_padding_needed = args.decoder_seq_length - decoder_seq_len
elif args.sequence_parallel or cp_size > 1:
if args.sequence_parallel and cp_size > 1:
# Padding to multiple of tp_size * cp_size*2 when using sequence parallel and context parallel
padding_factor = tp_size * cp_size * 2
elif cp_size > 1:
padding_factor = cp_size * 2
elif args.sequence_parallel:
padding_factor = tp_size
mp_padding_needed = int((decoder_seq_len + padding_factor - 1) // (padding_factor) * (padding_factor)) - decoder_seq_len
args.decoder_seq_length = decoder_seq_len + mp_padding_needed
else:
args.decoder_seq_length = decoder_seq_len
return mp_padding_needed
def model_provider(
pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
......@@ -82,12 +53,8 @@ def model_provider(
vision_model_type = "clip"
assert args.ckpt_format == 'torch', "Only ckpt-format torch is supported for VLM training currently."
if args.pipeline_model_parallel_size > 1:
assert not args.freeze_LM, "Freezing a pipeline parallel language model is not currently supported"
if args.encoder_pipeline_model_parallel_size == 1:
assert not args.freeze_ViT, "Freezing a vision encoder on its own pipeline rank is not currently supported"
assert not (args.context_parallel_size > 1 and args.pipeline_model_parallel_size > 1), "PP+CP is not yet supported by this script. \
Current mock dataset does not support natively packed sequence dataset required for correct PP comm shapes."
num_image_embeddings = get_num_image_embeddings(
args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token,
......@@ -108,7 +75,15 @@ def model_provider(
warnings.warn(
f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
)
mp_padding_needed = calculate_model_parallel_padding(decoder_seq_len)
mp_padding_needed = context_parallel.get_padding(
decoder_seq_len,
args.context_parallel_size,
args.tensor_model_parallel_size,
args.sequence_parallel,
args.decoder_tp_comm_overlap,
args.decoder_seq_length
)
args.decoder_seq_length = decoder_seq_len + mp_padding_needed
args.max_position_embeddings = max(args.max_position_embeddings, args.decoder_seq_length)
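Not part of the commit: the padding returned by context_parallel.get_padding rounds the decoder sequence length up so it shards evenly across SP/CP ranks, the same arithmetic as the removed calculate_model_parallel_padding above. A minimal sketch with illustrative names:

def sp_cp_padding(seq_len, tp_size, cp_size, sequence_parallel):
    # Pad so the sequence divides evenly: tp_size chunks for sequence parallelism,
    # 2 * cp_size chunks for context parallelism (each CP rank holds two slices
    # for load-balanced causal attention), or their product when both are enabled.
    if sequence_parallel and cp_size > 1:
        factor = tp_size * cp_size * 2
    elif cp_size > 1:
        factor = cp_size * 2
    elif sequence_parallel:
        factor = tp_size
    else:
        return 0
    return (seq_len + factor - 1) // factor * factor - seq_len

assert sp_cp_padding(1000, tp_size=4, cp_size=2, sequence_parallel=True) == 8  # pads 1000 up to 1008, a multiple of 16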
......@@ -129,7 +104,7 @@ def model_provider(
language_transformer_layer_spec = decoder_model_with_local_default_spec(
args.num_experts, args.moe_grouped_gemm
)
# Prepare mask type for any required padding to support CP/SP sequence sharding.
if mp_padding_needed > 0:
if language_transformer_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') == AttnMaskType.causal:
......@@ -189,11 +164,16 @@ def model_provider(
if args.virtual_pipeline_model_parallel_size:
raise NotImplementedError("virtual pipeline model parallelism is not supported yet.")
language_max_sequence_length = args.decoder_seq_length
if args.context_parallel_size > 1:
if args.use_packed_sequence or mp_padding_needed > 0:
# Use THD data format
language_max_sequence_length = args.decoder_seq_length * args.micro_batch_size
model = LLaVAModel(
language_transformer_config=language_transformer_config,
language_transformer_layer_spec=language_transformer_layer_spec,
language_vocab_size=args.padded_vocab_size,
language_max_sequence_length=args.decoder_seq_length,
language_max_sequence_length=language_max_sequence_length,
vision_transformer_config=vision_transformer_config,
vision_transformer_layer_spec=vision_transformer_layer_spec,
drop_vision_class_token=args.disable_vision_class_token,
......@@ -295,6 +275,7 @@ def _preprocess_data_for_llava(data):
return data
def get_batch(data_iterator):
"""Generate a batch.
......@@ -304,33 +285,6 @@ def get_batch(data_iterator):
Returns:
sample: A data sample with images, tokens, etc.
"""
def _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed):
batch_size = tokens.shape[0]
# Calculate the valid token seq len that LM backbone should compute on
combined_valid_seqlen = tokens.shape[1] + img_seq_len - mp_padding_needed
cu_seqlens = torch.arange(
0, (batch_size + 1) * (combined_valid_seqlen), step=(combined_valid_seqlen), dtype=torch.int32, device=tokens.device)
# Calculate the total padded token seq len
combined_padded_seqlen = tokens.shape[1] + img_seq_len
cu_seqlens_padded = None
qkv_format = 'sbhd'
if cp_size > 1:
# Provide cu_seqlens_<q/kv>_padded for CP support
cu_seqlens_padded = torch.arange(
0, (batch_size + 1) * (combined_padded_seqlen), step=(combined_padded_seqlen), dtype=torch.int32, device=tokens.device)
# CP with padding mask type requires THD format
qkv_format = 'thd'
packed_seq_params = PackedSeqParams(
cu_seqlens_q=cu_seqlens,
cu_seqlens_kv=cu_seqlens,
cu_seqlens_q_padded=cu_seqlens_padded,
cu_seqlens_kv_padded=cu_seqlens_padded,
max_seqlen_q=combined_padded_seqlen,
max_seqlen_kv=combined_padded_seqlen,
qkv_format=qkv_format,
)
return packed_seq_params
args = get_args()
cp_size = args.context_parallel_size
# Broadcast data.
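For context, not in the diff: with equal-length padded sequences, the cu_seqlens tensors built by the removed helper above (and presumably by context_parallel.get_packed_seq_params) are just evenly spaced offsets. A small sketch:

import torch

def even_cu_seqlens(batch_size, seqlen):
    # Cumulative sequence lengths for a batch of equal-length sequences:
    # [0, L, 2L, ..., B*L], the boundary offsets expected by packed/THD attention kernels.
    return torch.arange(0, (batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32)

print(even_cu_seqlens(3, 5))  # tensor([ 0,  5, 10, 15], dtype=torch.int32)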
......@@ -351,28 +305,49 @@ def get_batch(data_iterator):
labels = data_i["labels"].long()
loss_mask = data_f["loss_mask"].float()
images = data_f["image"].float()
if cp_size > 1 or args.sequence_parallel:
vision_model_type = "clip"
# Calculate the number of image embedding tokens that will be added to the text tokens
num_image_embeddings_per_tile = get_num_image_embeddings(
args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, 1
)
# Pad to make sure the text sequence can be sharded equally by CP chunks.
mp_padding_needed_for_text = calculate_model_parallel_padding(tokens.shape[1], text_only=True)
if mp_padding_needed_for_text > 0:
tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) for item in (tokens, position_ids, labels, loss_mask)]
# Image token mask must be supplied before distributed sequence to CP ranks.
image_token_mask = tokens == DEFAULT_IMAGE_TOKEN_INDEX
num_images_per_sample = torch.sum(image_token_mask, dim=-1)
img_seq_len = (num_image_embeddings_per_tile * num_images_per_sample - num_images_per_sample).max()
packed_seq_params = _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text)
# slice batch along sequence dimension for context parallelism
batch = get_batch_on_this_cp_rank({"tokens": tokens, "position_ids": position_ids})
mp_padding_needed_for_text = context_parallel.get_padding(
tokens.shape[1] + img_seq_len,
args.context_parallel_size,
args.tensor_model_parallel_size,
args.sequence_parallel,
args.decoder_tp_comm_overlap,
args.decoder_seq_length
)
if mp_padding_needed_for_text > 0:
tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) for item in (tokens, position_ids, labels, loss_mask)]
packed_seq_params = context_parallel.get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text, cp_size, args.use_packed_sequence)
if packed_seq_params.qkv_format == 'thd':
# Reshape from [B, S] to [1, T] (packed THD layout, T = B * S)
tokens = (
tokens.contiguous()
.view(tokens.shape[0] * tokens.shape[1])
.unsqueeze(0)
)
position_ids = (
position_ids.contiguous()
.view(position_ids.shape[0] * position_ids.shape[1])
.unsqueeze(0)
)
labels = labels.view(labels.shape[0] * labels.shape[1]).unsqueeze(0)
loss_mask = loss_mask.view(
loss_mask.shape[0] * loss_mask.shape[1]
).unsqueeze(0)
attention_mask = None # Use the attention mask type defined in layer spec. Typically no mask for the vision model and a causal mask for the language model.
return batch["tokens"], batch["position_ids"], labels, images, loss_mask, attention_mask, image_token_mask, packed_seq_params
return tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params
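A quick illustration, not part of the commit, of the THD reshape above: the batch dimension is folded into a single packed token dimension of length T = B * S, with cu_seqlens marking where each original sample starts.

import torch

B, S = 2, 4
tokens = torch.arange(B * S).view(B, S)                  # [B, S]
packed = tokens.contiguous().view(B * S).unsqueeze(0)    # [1, T] with T = B * S
assert packed.shape == (1, B * S)
# The original rows are recoverable because samples are laid out back to back.
assert torch.equal(packed.view(B, S), tokens)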
def forward_step(data_iterator, model: LLaVAModel):
......@@ -390,11 +365,11 @@ def forward_step(data_iterator, model: LLaVAModel):
# Get the batch.
timers('batch-generator', log_level=2).start()
tokens, position_ids, labels, images, loss_mask, attention_mask, image_token_mask, packed_seq_params = get_batch(data_iterator)
tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params = get_batch(data_iterator)
timers('batch-generator').stop()
output_tensor, loss_mask = model(
images, tokens, position_ids, attention_mask, labels, loss_mask, image_token_mask=image_token_mask, packed_seq_params=packed_seq_params
images, tokens, position_ids, attention_mask, labels, loss_mask, packed_seq_params=packed_seq_params
)
return output_tensor, partial(loss_func, loss_mask)
......@@ -419,6 +394,12 @@ def add_vlm_extra_args(parser):
group.add_argument("--decoder-tp-comm-overlap", action="store_true", default=False, help="Enables the overlap of "
"Tensor parallel communication and GEMM kernels in Decoder only. "
"Please provide decoder-seq-length when using this feature.")
group.add_argument(
"--use-packed-sequence",
action="store_true",
default=False,
help="Use packed sequence",
)
return parser
......