Commit 2c63b5cd authored by wangxj

Upgrade to version 0.12

parent c271aaae
Pipeline #2451 passed with stage
@@ -59,7 +59,7 @@ def validate_yaml(args, defaults={}):
                (args.world_size // args.model_parallel.tensor_model_parallel_size))
    args.model_parallel.transformer_pipeline_model_parallel_size = (
        args.model_parallel.pipeline_model_parallel_size - 1
-        if args.standalone_embedding_stage else
+        if args.account_for_embedding_in_pipeline_split else
        args.model_parallel.pipeline_model_parallel_size
    )
    # Checks.
...
@@ -154,7 +154,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples):
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path)
        ],
-        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        path_to_cache=args.data_cache_path,
        tokenizer=tokenizer,
...
@@ -35,8 +35,7 @@ from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec,
)
-import torch._dynamo
-torch._dynamo.config.suppress_errors = True

stimer = StragglerDetector()
@@ -64,6 +63,15 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
            # record stack information for the trace events
            trace_alloc_record_context=True)
+
+        def oom_observer(device, alloc, device_alloc, device_free):
+            # snapshot right after an OOM happened
+            print('saving allocated state during OOM')
+            snapshot = torch.cuda.memory._snapshot()
+            from pickle import dump
+            dump(snapshot, open(f"oom_rank-{torch.distributed.get_rank()}_{args.memory_snapshot_path}", 'wb'))
+
+        torch._C._cuda_attach_out_of_memory_observer(oom_observer)

    print_rank_0('building GPT model ...')
    # Experimental loading arguments from yaml
    if args.yaml_cfg is not None:
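For context, the hook added above uses PyTorch's private memory-history APIs to dump an allocator snapshot when an out-of-memory error fires. A minimal standalone sketch of the same pattern, assuming a fixed output path and a torch release where these private functions keep the signatures shown here (they have changed between versions):

```python
# Sketch of the OOM-snapshot pattern used in the hunk above.
# _record_memory_history, _snapshot and _cuda_attach_out_of_memory_observer
# are private torch APIs; their exact signatures vary across releases.
import pickle
import torch


def attach_oom_snapshot(path="oom_snapshot.pickle"):
    # Record allocation history so the snapshot carries allocation stack traces.
    torch.cuda.memory._record_memory_history(True, trace_alloc_record_context=True)

    def oom_observer(device, alloc, device_alloc, device_free):
        # Called by the CUDA caching allocator right after an OOM occurs.
        snapshot = torch.cuda.memory._snapshot()
        with open(path, "wb") as f:
            pickle.dump(snapshot, f)

    torch._C._cuda_attach_out_of_memory_observer(oom_observer)
```

The saved pickle can then be inspected with PyTorch's memory-visualization tooling to see which allocations were live at the moment of the OOM.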
@@ -91,11 +99,11 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
        if use_te:
            transformer_layer_spec = get_gpt_layer_with_transformer_engine_spec(
                args.num_experts, args.moe_grouped_gemm,
-                args.qk_layernorm, args.multi_latent_attention, args.fp8)
+                args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)
        else:
            transformer_layer_spec = get_gpt_layer_local_spec(
                args.num_experts, args.moe_grouped_gemm,
-                args.qk_layernorm, args.multi_latent_attention)
+                args.qk_layernorm, args.multi_latent_attention, args.moe_use_legacy_grouped_gemm)

    build_model_context = nullcontext
    build_model_context_args = {}
@@ -128,7 +136,9 @@ def model_provider(pre_process=True, post_process=True) -> Union[GPTModel, megatron.legacy.model.GPTModel]:
            rotary_base=args.rotary_base,
            rope_scaling=args.use_rope_scaling
        )
+    # model = torch.compile(model, mode="max-autotune-no-cudagraphs")
+    print_rank_0(model)

    return model
@@ -148,8 +158,8 @@ def get_batch(data_iterator):
    return batch.values()


-# define spiky loss as a variation of 20% or more
-SPIKY_LOSS_PERC = 0.2
+# define spiky loss as a loss that's 10x the max loss observed
+SPIKY_LOSS_FACTOR = 10


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
@@ -185,11 +195,22 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
            tolerance=0.0,  # forward pass calculations are determinisic
            fatal=True,
        )
+        rerun_state_machine.validate_result(
+            result=loss[0],
+            rejection_func=torch.isinf,
+            message="found Inf in local forward loss calculation",
+            tolerance=0.0,  # forward pass calculations are determinisic
+            fatal=True,
+        )
    # Check for spiky loss
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss[0],
-            rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
+            rejection_func=partial(
+                rerun_state_machine.is_unexpectedly_large,
+                threshold=SPIKY_LOSS_FACTOR,
+                context="loss",
+            ),
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are determinisic
            fatal=False,
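The rejection_func change above uses functools.partial to pre-bind the threshold and context, so the validator only has to call a one-argument predicate on the loss. A generic illustration of that binding, with a hypothetical stand-in for Megatron's is_unexpectedly_large check:

```python
from functools import partial

import torch


def is_unexpectedly_large(value: torch.Tensor, threshold: float, context: str = "") -> bool:
    # Hypothetical stand-in: flag values above a fixed threshold. The real
    # Megatron check instead compares against threshold x the largest value
    # observed so far for this context.
    return bool(value > threshold)


# partial pre-binds the extra keyword arguments, leaving value as the only input.
rejection_func = partial(is_unexpectedly_large, threshold=10, context="loss")
print(rejection_func(torch.tensor(3.0)))   # False
print(rejection_func(torch.tensor(42.0)))  # True
```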
@@ -250,7 +271,6 @@ def core_gpt_dataset_config_from_args(args):
        sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
-        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
...
File mode changed from 100755 to 100644
@@ -104,8 +104,8 @@ def get_batch(data_iterator):
    return batch.values()


-# define spiky loss as a variation of 20% or more
-SPIKY_LOSS_PERC = 0.2
+# define spiky loss as a loss that's 10x the max loss observed
+SPIKY_LOSS_FACTOR = 10


def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
@@ -141,11 +141,22 @@ def loss_func(loss_mask: torch.Tensor, output_tensor: torch.Tensor):
            tolerance=0.0,  # forward pass calculations are determinisic
            fatal=True,
        )
+        rerun_state_machine.validate_result(
+            result=loss[0],
+            rejection_func=torch.isinf,
+            message="found Inf in local forward loss calculation",
+            tolerance=0.0,  # forward pass calculations are determinisic
+            fatal=True,
+        )
    # Check for spiky loss
    if args.check_for_spiky_loss:
        rerun_state_machine.validate_result(
            result=loss[0],
-            rejection_func=partial(rerun_state_machine.is_spiky_loss, threshold=SPIKY_LOSS_PERC),
+            rejection_func=partial(
+                rerun_state_machine.is_unexpectedly_large,
+                threshold=SPIKY_LOSS_FACTOR,
+                context="loss",
+            ),
            message="Spiky loss",
            tolerance=0.0,  # forward pass calculations are determinisic
            fatal=False,
@@ -207,7 +218,6 @@ def core_gpt_dataset_config_from_args(args):
        sequence_length=args.seq_length,
        blend=blend,
        blend_per_split=blend_per_split,
-        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        num_dataset_builder_threads=args.num_dataset_builder_threads,
        path_to_cache=args.data_cache_path,
...
@@ -189,7 +189,6 @@ def train_valid_test_datasets_provider(train_valid_test_num_samples):
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path)
        ],
-        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        split_preprocessing=retro_config.retro_split_preprocessing,
        path_to_cache=args.data_cache_path,
...
@@ -141,6 +141,8 @@ def model_provider(
        share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights,
        position_embedding_type=args.position_embedding_type,
        rotary_percent=args.rotary_percent,
+        relative_attention_num_buckets=args.relative_attention_num_buckets,
+        relative_attention_max_distance=args.relative_attention_max_distance,
        add_encoder=add_encoder,
        add_decoder=add_decoder,
    )
@@ -226,7 +228,6 @@ def train_valid_test_datasets_provider(train_val_test_num_samples: int):
            get_blend_from_list(args.valid_data_path),
            get_blend_from_list(args.test_data_path),
        ],
-        renormalize_blend_weights=args.renormalize_blend_weights,
        split=args.split,
        path_to_cache=args.data_cache_path,
        tokenizer=tokenizer,
...
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
File mode changed from 100755 to 100644
@@ -22,42 +22,13 @@ from megatron.core.models.vision.vit_layer_specs import (
    get_vit_layer_with_local_spec,
)
from megatron.core.transformer.spec_utils import import_module
-from megatron.core.packed_seq_params import PackedSeqParams
from megatron.training import get_args, get_timers, get_tokenizer, pretrain, print_rank_0
from megatron.training.arguments import core_transformer_config_from_args
from megatron.training.utils import get_batch_on_this_cp_rank
from megatron.core import mpu
+from megatron.core.models.multimodal import context_parallel
from pretrain_gpt import loss_func

-def calculate_model_parallel_padding(decoder_seq_len, text_only=False):
-    args = get_args()
-    cp_size = args.context_parallel_size
-    tp_size = args.tensor_model_parallel_size
-    mp_padding_needed = 0
-    # TP Comm overlap is performed with combined text+image embeddings.
-    # text_only flag skips using the full sequence length to calculate padding and uses
-    # the provided decoder_seq_len
-    if args.sequence_parallel and args.decoder_tp_comm_overlap and not text_only:
-        # If TP Comm Overlap is enabled for combined text+image embedding in LM backbone,
-        # user needs to provide decoder_seq_length with any potential padding needed for SP+CP
-        assert args.decoder_seq_length is not None, \
-            "Please provide --decoder-seq-length when using TP Comm overlap for LM backbone"
-        mp_padding_needed = args.decoder_seq_length - decoder_seq_len
-    elif args.sequence_parallel or cp_size > 1:
-        if args.sequence_parallel and cp_size > 1:
-            # Padding to multiple of tp_size * cp_size*2 when using sequence parallel and context parallel
-            padding_factor = tp_size * cp_size * 2
-        elif cp_size > 1:
-            padding_factor = cp_size * 2
-        elif args.sequence_parallel:
-            padding_factor = tp_size
-        mp_padding_needed = int((decoder_seq_len + padding_factor - 1) // (padding_factor) * (padding_factor)) - decoder_seq_len
-        args.decoder_seq_length = decoder_seq_len + mp_padding_needed
-    else:
-        args.decoder_seq_length = decoder_seq_len
-    return mp_padding_needed

def model_provider(
    pre_process=True, post_process=True, add_encoder=True, add_decoder=True, parallel_output=True
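The helper removed above (its role is taken over by context_parallel.get_padding, called later in this diff) rounds the decoder sequence length up to a multiple of a model-parallel factor. A condensed sketch of just that arithmetic, using a hypothetical standalone function name:

```python
# Condensed sketch of the padding rule from the removed helper above:
# pad the sequence length up to a multiple of the model-parallel factor.
def sequence_padding(seq_len: int, tp_size: int, cp_size: int, sequence_parallel: bool) -> int:
    if sequence_parallel and cp_size > 1:
        factor = tp_size * cp_size * 2   # SP + CP: multiple of tp * cp * 2
    elif cp_size > 1:
        factor = cp_size * 2             # CP only: multiple of cp * 2
    elif sequence_parallel:
        factor = tp_size                 # SP only: multiple of tp
    else:
        return 0                         # no sharding, no padding needed
    padded = (seq_len + factor - 1) // factor * factor  # round up to multiple
    return padded - seq_len


# Example: seq_len=1001, tp=2, cp=2, SP on -> factor 8 -> padded to 1008 -> padding 7
print(sequence_padding(1001, tp_size=2, cp_size=2, sequence_parallel=True))  # 7
```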
@@ -82,12 +53,8 @@ def model_provider(
    vision_model_type = "clip"
    assert args.ckpt_format == 'torch', "Only ckpt-format torch is supported for VLM training currently."
-    if args.pipeline_model_parallel_size > 1:
-        assert not args.freeze_LM, "Freezing a pipeline parallel language model is not currently supported"
-        if args.encoder_pipeline_model_parallel_size == 1:
-            assert not args.freeze_ViT, "Freezing a vision encoder on its own pipeline rank is not currently supported"
+    assert not (args.context_parallel_size > 1 and args.pipeline_model_parallel_size > 1), "PP+CP is not yet supported by this script. \
+        Current mock dataset does not support natively packed sequence dataset required for correct PP comm shapes."
    num_image_embeddings = get_num_image_embeddings(
        args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token,
@@ -108,7 +75,15 @@ def model_provider(
        warnings.warn(
            f"Changed seq_length and encoder_seq_length (vision model sequence length) from {old_seq_length} to num_image_tokens ({num_image_embeddings})"
        )
-    mp_padding_needed = calculate_model_parallel_padding(decoder_seq_len)
+    mp_padding_needed = context_parallel.get_padding(
+        decoder_seq_len,
+        args.context_parallel_size,
+        args.tensor_model_parallel_size,
+        args.sequence_parallel,
+        args.decoder_tp_comm_overlap,
+        args.decoder_seq_length
+    )
+    args.decoder_seq_length = decoder_seq_len + mp_padding_needed
    args.max_position_embeddings = max(args.max_position_embeddings, args.decoder_seq_length)
@@ -129,7 +104,7 @@ def model_provider(
        language_transformer_layer_spec = decoder_model_with_local_default_spec(
            args.num_experts, args.moe_grouped_gemm
        )
    # Prepare mask type for any required padding to support CP/SP sequence sharding.
    if mp_padding_needed > 0:
        if language_transformer_layer_spec.submodules.self_attention.params.get('attn_mask_type', '') == AttnMaskType.causal:
@@ -189,11 +164,16 @@ def model_provider(
    if args.virtual_pipeline_model_parallel_size:
        raise NotImplementedError("virtual pipeline model parallelism is not supported yet.")
+    language_max_sequence_length = args.decoder_seq_length
+    if args.context_parallel_size > 1:
+        if args.use_packed_sequence or mp_padding_needed > 0:
+            # Use THD data format
+            language_max_sequence_length = args.decoder_seq_length * args.micro_batch_size
    model = LLaVAModel(
        language_transformer_config=language_transformer_config,
        language_transformer_layer_spec=language_transformer_layer_spec,
        language_vocab_size=args.padded_vocab_size,
-        language_max_sequence_length=args.decoder_seq_length,
+        language_max_sequence_length=language_max_sequence_length,
        vision_transformer_config=vision_transformer_config,
        vision_transformer_layer_spec=vision_transformer_layer_spec,
        drop_vision_class_token=args.disable_vision_class_token,
@@ -295,6 +275,7 @@ def _preprocess_data_for_llava(data):
    return data


def get_batch(data_iterator):
    """Generate a batch.
@@ -304,33 +285,6 @@ def get_batch(data_iterator):
    Returns:
        sample: A data sample with images, tokens, etc.
    """
-    def _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed):
-        batch_size = tokens.shape[0]
-        # Calculate the valid token seq len that LM backbone should compute on
-        combined_valid_seqlen = tokens.shape[1] + img_seq_len - mp_padding_needed
-        cu_seqlens = torch.arange(
-            0, (batch_size + 1) * (combined_valid_seqlen), step=(combined_valid_seqlen), dtype=torch.int32, device=tokens.device)
-        # Calculate the total padded token seq len
-        combined_padded_seqlen = tokens.shape[1] + img_seq_len
-        cu_seqlens_padded = None
-        qkv_format = 'sbhd'
-        if cp_size > 1:
-            # Provide cu_seqlens_<q/kv>_padded for CP support
-            cu_seqlens_padded = torch.arange(
-                0, (batch_size + 1) * (combined_padded_seqlen), step=(combined_padded_seqlen), dtype=torch.int32, device=tokens.device)
-            # CP with padding mask type requires THD format
-            qkv_format = 'thd'
-        packed_seq_params = PackedSeqParams(
-            cu_seqlens_q=cu_seqlens,
-            cu_seqlens_kv=cu_seqlens,
-            cu_seqlens_q_padded=cu_seqlens_padded,
-            cu_seqlens_kv_padded=cu_seqlens_padded,
-            max_seqlen_q=combined_padded_seqlen,
-            max_seqlen_kv=combined_padded_seqlen,
-            qkv_format=qkv_format,
-        )
-        return packed_seq_params
    args = get_args()
    cp_size = args.context_parallel_size
    # Broadcast data.
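The cu_seqlens tensors built by the removed helper above are cumulative sequence boundaries in the flattened (THD) layout; for equal-length sequences they reduce to an arange stepped by the combined length. A small illustration:

```python
import torch

# Three samples, each with a combined (text + image) length of 8 tokens.
batch_size, combined_seqlen = 3, 8
cu_seqlens = torch.arange(
    0, (batch_size + 1) * combined_seqlen, step=combined_seqlen, dtype=torch.int32
)
print(cu_seqlens)  # tensor([ 0,  8, 16, 24], dtype=torch.int32)
```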
@@ -351,28 +305,49 @@ def get_batch(data_iterator):
    labels = data_i["labels"].long()
    loss_mask = data_f["loss_mask"].float()
    images = data_f["image"].float()

    if cp_size > 1 or args.sequence_parallel:
        vision_model_type = "clip"
        # Calculate the number of image embedding tokens will be added to text tokens
        num_image_embeddings_per_tile = get_num_image_embeddings(
            args.img_h, args.img_w, args.patch_dim, vision_model_type, args.disable_vision_class_token, 1
        )
        # Pad to make sure the text sequence can be sharded equally by CP chunks.
-        mp_padding_needed_for_text = calculate_model_parallel_padding(tokens.shape[1], text_only=True)
-        if mp_padding_needed_for_text > 0:
-            tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) for item in (tokens, position_ids, labels, loss_mask)]
-        # Image token mask must be supplied before distributed sequence to CP ranks.
        image_token_mask = tokens == DEFAULT_IMAGE_TOKEN_INDEX
        num_images_per_sample = torch.sum(image_token_mask, dim=-1)
        img_seq_len = (num_image_embeddings_per_tile * num_images_per_sample - num_images_per_sample).max()
-        packed_seq_params = _get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text)
-
-    # slice batch along sequence dimension for context parallelism
-    batch = get_batch_on_this_cp_rank({"tokens": tokens, "position_ids": position_ids})
+        mp_padding_needed_for_text = context_parallel.get_padding(
+            tokens.shape[1] + img_seq_len,
+            args.context_parallel_size,
+            args.tensor_model_parallel_size,
+            args.sequence_parallel,
+            args.decoder_tp_comm_overlap,
+            args.decoder_seq_length
+        )
+        if mp_padding_needed_for_text > 0:
+            tokens, position_ids, labels, loss_mask = [torch.nn.functional.pad(item, (0, mp_padding_needed_for_text)) for item in (tokens, position_ids, labels, loss_mask)]
+        packed_seq_params = context_parallel.get_packed_seq_params(tokens, img_seq_len, mp_padding_needed_for_text, cp_size, args.use_packed_sequence)
+        if packed_seq_params.qkv_format == 'thd':
+            # Reshape from [B,S] to [T,1]
+            tokens = (
+                tokens.contiguous()
+                .view(tokens.shape[0] * tokens.shape[1])
+                .unsqueeze(0)
+            )
+            position_ids = (
+                position_ids.contiguous()
+                .view(position_ids.shape[0] * position_ids.shape[1])
+                .unsqueeze(0)
+            )
+            labels = labels.view(labels.shape[0] * labels.shape[1]).unsqueeze(0)
+            loss_mask = loss_mask.view(
+                loss_mask.shape[0] * loss_mask.shape[1]
+            ).unsqueeze(0)

    attention_mask = None  # Use the attention mask type defined in layer spec. Typically no mask for the vision model and causal mask for the vision model.

-    return batch["tokens"], batch["position_ids"], labels, images, loss_mask, attention_mask, image_token_mask, packed_seq_params
+    return tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params


def forward_step(data_iterator, model: LLaVAModel):
@@ -390,11 +365,11 @@ def forward_step(data_iterator, model: LLaVAModel):
    # Get the batch.
    timers('batch-generator', log_level=2).start()
-    tokens, position_ids, labels, images, loss_mask, attention_mask, image_token_mask, packed_seq_params = get_batch(data_iterator)
+    tokens, position_ids, labels, images, loss_mask, attention_mask, packed_seq_params = get_batch(data_iterator)
    timers('batch-generator').stop()

    output_tensor, loss_mask = model(
-        images, tokens, position_ids, attention_mask, labels, loss_mask, image_token_mask=image_token_mask, packed_seq_params=packed_seq_params
+        images, tokens, position_ids, attention_mask, labels, loss_mask, packed_seq_params=packed_seq_params
    )

    return output_tensor, partial(loss_func, loss_mask)
@@ -419,6 +394,12 @@ def add_vlm_extra_args(parser):
    group.add_argument("--decoder-tp-comm-overlap", action="store_true", default=False, help="Enables the overlap of "
                       "Tensor parallel communication and GEMM kernels in Decoder only. "
                       "Please provide decoder-seq-length when using this feature.")
+    group.add_argument(
+        "--use-packed-sequence",
+        action="store_true",
+        default=False,
+        help="Use packed sequence",
+    )
    return parser
...