Commit fef9c0d9 authored by wxj

Update torchprof support

parent 425a2473
Pipeline #2105 failed with stages in 0 seconds
@@ -19,16 +19,34 @@ export CUDA_DEVICE_MAX_CONNECTIONS=1
export NCCL_IB_HCA=mlx5_1,mlx5_2
export NCCL_NET_GDR_LEVEL=SYS
export NCCL_NET_GDR_READ=0
+export GLOG_minloglevel=3 # print error-level NCCL logs only
source /opt/dtk/env.sh
# TE needs the hipblaslt library for its GEMM calls
# export LD_LIBRARY_PATH=/data/hipblaslt-install-0904/lib:$LD_LIBRARY_PATH
+# updated rocblas
+export LD_LIBRARY_PATH=/data/rocblas-install/lib:$LD_LIBRARY_PATH
+# # add synchronization for profiler capture
+# export GPU_FLUSH_ON_EXECUTION=1
+# export HIP_DIRECT_DISPATCH=0
CHECKPOINT_PATH=./tmp_7b #$1 #<Specify path>
TENSORBOARD_LOGS_PATH=./tmp_7b #$2 #<Specify path>
-DATA_PATH="/datasets/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document
+DATA_PATH="/data/datasets/nemo_pretrain/oscar-1GB/oscar-1GB-llama_text_document" #<Specify path and file prefix>_text_document
+# GPT_MODEL_ARGS=(
+#     --num-layers 32
+#     --hidden-size 5120
+#     --ffn-hidden-size 13824
+#     --num-attention-heads 40
+#     --seq-length 4096 #4096
+#     --max-position-embeddings 32768 #4096
+#     --num-query-groups 40
+#     --group-query-attention
+# )
GPT_MODEL_ARGS=(
-    --num-layers 36
+    --num-layers 6
    --hidden-size 4096
    --ffn-hidden-size 11008
    --num-attention-heads 32
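
For orientation (this note and the snippet below are editorial, not part of the commit): the active GPT_MODEL_ARGS keep a LLaMA-7B-like width (hidden 4096, FFN 11008, 32 heads) but cut --num-layers from 36 to 6, presumably to keep profiling runs short. A rough parameter-count sketch, assuming a LLaMA-style SwiGLU decoder with RMSNorm and a 32000-token vocabulary (both assumptions, not stated in the script):

# Editorial sketch: rough parameter count for the configs above.
def estimate_params(num_layers, hidden, ffn_hidden, vocab=32000):
    attn = 4 * hidden * hidden            # Q, K, V and output projections
    mlp = 3 * hidden * ffn_hidden         # gate, up and down projections (SwiGLU)
    per_layer = attn + mlp + 2 * hidden   # plus two RMSNorm weight vectors
    return num_layers * per_layer + vocab * hidden + hidden  # embeddings + final norm

print(f"36 layers: {estimate_params(36, 4096, 11008) / 1e9:.2f}B parameters")
print(f" 6 layers: {estimate_params(6, 4096, 11008) / 1e9:.2f}B parameters")
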
@@ -36,17 +54,18 @@ GPT_MODEL_ARGS=(
    --max-position-embeddings 4096
)
# export NVTE_FLASH_ATTN=1 # use cutlass
-# export NVTE_FLASH_ATTN_TRITON=1 # use triton_fa
+export NVTE_FLASH_ATTN_TRITON=1 # use triton_fa
# --transformer-impl transformer_engine
# --use-mcore-models
+# --transformer-impl local
+# --use-legacy-models
TRAINING_ARGS=(
-    --transformer-impl local
-    --use-legacy-models
+    --transformer-impl transformer_engine
+    --use-mcore-models
    --micro-batch-size 1
-    --global-batch-size 60 #240 #512 #64
-    --train-iters 100
+    --global-batch-size 6 #240 #60 #512 #64
+    --train-iters 10
    --weight-decay 0.1
    --adam-beta1 0.9
    --adam-beta2 0.95
@@ -54,24 +73,32 @@ TRAINING_ARGS=(
    --clip-grad 1.0
    --bf16
    --use-distributed-optimizer
-    --ckpt-format torch
    --disable-bias-linear
-    --overlap-grad-reduce
    --attention-dropout 0
    --hidden-dropout 0
-    --ddp-average-in-collective
-    --recompute-granularity full
-    --recompute-num-layers 5
-    --recompute-method block
    --no-gradient-accumulation-fusion
-    --add-qkv-bias
    --swiglu
    --lr 3.0e-5
    --lr-decay-style cosine
    --min-lr 3.0e-6
    --lr-warmup-iters 1
+    --ckpt-format torch
+    --ddp-average-in-collective
+    --recompute-granularity full
+    --recompute-num-layers 5 #0 #
+    --recompute-method block
+    --overlap-grad-reduce
    --use-flash-attn-triton
)
-# --use-flash-attn-ck
+# --add-qkv-bias # qwen
+# --ckpt-format torch
+# --ddp-average-in-collective
+# --recompute-granularity full
+# --recompute-num-layers 5
+# --recompute-method block
+# --overlap-grad-reduce
+# --use-flash-attn-cutlass
# --use-flash-attn-triton
MODEL_PARALLEL_ARGS=(
@@ -88,7 +115,7 @@ DATA_ARGS=(
    --normalization RMSNorm
    --no-position-embedding
    --tokenizer-type Llama2Tokenizer
-    --tokenizer-model /path/to/llama2_7b_hf/tokenizer.model
+    --tokenizer-model /data/model_weights/llama2_7b_hf/tokenizer.model
)
EVAL_AND_LOGGING_ARGS=(
@@ -102,6 +129,15 @@ EVAL_AND_LOGGING_ARGS=(
    --tensorboard-dir $TENSORBOARD_LOGS_PATH
)
+PROFILE_ARGS=(
+    --profile
+    --profile-step-start 4
+    --profile-step-end 5
+    --use-pytorch-profiler
+    --profile-ranks 0 3
+    --profile-dir prof_data
+)
RANK=$OMPI_COMM_WORLD_RANK
LOCAL_RANK=$OMPI_COMM_WORLD_LOCAL_RANK
WORLD_SIZE=$OMPI_COMM_WORLD_SIZE
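
The script takes its ranks from the OpenMPI environment (OMPI_COMM_WORLD_*), so it is meant to be launched with mpirun rather than torchrun. A minimal editorial sketch of how such variables are typically consumed on the Python side (assumes MASTER_ADDR/MASTER_PORT are exported by the launcher; this is not code from this repo):

# Editorial sketch: mapping mpirun-provided ranks to torch.distributed.
import os
import torch
import torch.distributed as dist

rank = int(os.environ["OMPI_COMM_WORLD_RANK"])
local_rank = int(os.environ["OMPI_COMM_WORLD_LOCAL_RANK"])
world_size = int(os.environ["OMPI_COMM_WORLD_SIZE"])

if torch.cuda.is_available():
    torch.cuda.set_device(local_rank)  # one visible GPU per local rank

dist.init_process_group(
    backend="nccl" if torch.cuda.is_available() else "gloo",
    rank=rank,
    world_size=world_size,
)
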
@@ -122,47 +158,51 @@ APP="python -u pretrain_gpt.py \
    ${DATA_ARGS[@]} \
    ${EVAL_AND_LOGGING_ARGS[@]} \
    ${DISTRIBUTED_ARGS[@]} \
+    ${PROFILE_ARGS[@]} \
    "
+export HIP_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3 # 4,5,6,7 #,
+# export CUDA_VISIBLE_DEVICES=4,5,6,7 # 0,1,2,3,
+# ${APP}
case ${LOCAL_RANK} in
[0])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
    ;;
[1])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
    ;;
[2])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
    ;;
[3])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
+    # export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+    ${APP}
+    # numactl --cpunodebind=0 --membind=0 ${APP}
    ;;
-[4])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+# [4])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
-[5])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+# [5])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
-[6])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+# [6])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
-[7])
-    export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
-    # ${APP}
-    numactl --cpunodebind=0 --membind=0 ${APP}
-    ;;
+# [7])
+# export CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7
+# ${APP}
+# # numactl --cpunodebind=0 --membind=0 ${APP}
+# ;;
esac
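
The launch block now sets a single global HIP_VISIBLE_DEVICES=4,5,6,7 (ROCm's counterpart to CUDA_VISIBLE_DEVICES), runs ${APP} directly for local ranks 0-3, and comments out the per-rank exports, the numactl binding, and the rank 4-7 branches. A small editorial sketch, not from this repo, of how such a mask looks from inside a process (values are illustrative):

# Editorial sketch: the device mask must be set before the first CUDA/HIP call;
# torch then only sees the allowed devices, renumbered from zero.
import os
os.environ.setdefault("HIP_VISIBLE_DEVICES", "4,5,6,7")   # ROCm builds
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "4,5,6,7")  # CUDA builds

import torch  # import after setting the mask

print(torch.cuda.device_count())  # on an 8-GPU node this reports 4 (cuda:0..cuda:3)
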
@@ -643,7 +643,7 @@ def validate_args(args, defaults={}):
            '--decoupled-lr and --decoupled-min-lr is not supported in legacy models.'
    # FlashAttention
-    args.use_flash_attn = args.use_flash_attn_ck or args.use_flash_attn_triton
+    args.use_flash_attn = args.use_flash_attn_cutlass or args.use_flash_attn_triton
    # Legacy RoPE arguments
    if args.use_rotary_position_embeddings:
@@ -1265,6 +1265,8 @@ def _add_training_args(parser):
                       dest='use_pytorch_profiler')
    group.add_argument('--profile-ranks', nargs='+', type=int, default=[0],
                       help='Global ranks to profile.')
+    group.add_argument('--profile-dir', type=str, default="./",
+                       help='Directory in which to save profiler traces.')
    group.add_argument('--record-memory-history', action="store_true", default=False,
                       help='Record memory history in last rank.')
    group.add_argument('--memory-snapshot-path', type=str, default="snapshot.pickle",
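
The new --profile-dir option sits next to the existing profiler flags. A standalone editorial sketch of the same argparse pattern (not Megatron's actual parser), showing how --profile-ranks 0 3 from PROFILE_ARGS becomes a Python list:

# Editorial sketch: standalone version of the profiling CLI options.
import argparse

parser = argparse.ArgumentParser()
group = parser.add_argument_group(title='profiling')
group.add_argument('--profile', action='store_true')
group.add_argument('--profile-step-start', type=int, default=10)
group.add_argument('--profile-step-end', type=int, default=12)
group.add_argument('--use-pytorch-profiler', action='store_true')
group.add_argument('--profile-ranks', nargs='+', type=int, default=[0])
group.add_argument('--profile-dir', type=str, default='./')

args = parser.parse_args(
    '--profile --use-pytorch-profiler --profile-ranks 0 3 --profile-dir prof_data'.split())
print(args.profile_ranks)  # [0, 3]
print(args.profile_dir)    # prof_data
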
@@ -1358,7 +1360,7 @@ def _add_training_args(parser):
    group.add_argument('--cross-entropy-loss-fusion', action='store_true',
                       help='Enabled fusion of cross entropy loss calculation.',
                       dest='cross_entropy_loss_fusion')
-    group.add_argument('--use-flash-attn-ck', action='store_true',
+    group.add_argument('--use-flash-attn-cutlass', action='store_true',
                       help='use FlashAttention implementation of attention. '
                       'https://arxiv.org/abs/2205.14135')
    group.add_argument('--use-flash-attn-triton', action='store_true',
...
@@ -135,6 +135,13 @@ def num_floating_point_operations(args, batch_size):
    # - 2x: A GEMM of a m*n tensor with a n*k tensor requires 2mnk floating-point operations.
    expansion_factor = 3 * 2 * 2
+    # print(f"batch_size: {batch_size}, \
+    #     query_projection_to_hidden_size_ratio: {query_projection_to_hidden_size_ratio}, \
+    #     num_experts_routed_to: {num_experts_routed_to}, \
+    #     gated_linear_multiplier: {gated_linear_multiplier}, \
+    #     shared_expert_ffn_hidden_size: {shared_expert_ffn_hidden_size}, \
+    #     gated_linear_multiplier: {gated_linear_multiplier}, \
+    #     ")
    return (
        expansion_factor
        * batch_size
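
The 2mnk rule quoted in the comment above is easy to sanity-check in isolation. An editorial sketch (not part of the commit) for the MLP projections of the 7B-width config, using the 3 x 2 x 2 expansion factor the function builds on:

# Editorial sketch: FLOPs for one GEMM and for the MLP block, fwd + bwd.
# Multiplying an (m x n) matrix by an (n x k) matrix takes m*k dot products
# of length n, i.e. roughly m*n*k multiplies + m*n*k adds = 2*m*n*k FLOPs.
def gemm_flops(m, n, k):
    return 2 * m * n * k

seq, hidden, ffn_hidden = 4096, 4096, 11008      # values from GPT_MODEL_ARGS above
forward = gemm_flops(seq, hidden, ffn_hidden)    # h -> ffn_h projection, forward pass

# expansion_factor = 3 * 2 * 2:
#   3x -> forward GEMM plus the two backward GEMMs (wgrad and dgrad)
#   2x -> two stacked GEMMs per MLP (h -> ffn_h and ffn_h -> h)
#   2x -> the 2*m*n*k rule itself (already folded into gemm_flops)
print(f"single forward GEMM: {forward / 1e12:.2f} TFLOPs")
print(f"MLP fwd+bwd total:   {3 * 2 * forward / 1e12:.2f} TFLOPs")
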
@@ -1214,8 +1221,8 @@ def post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteratio
        if args.use_pytorch_profiler:
            assert prof is not None
            prof.stop()
-        else:
-            torch.cuda.cudart().cudaProfilerStop()
+            print_rank_0(f"prof stop!")
    # Manual garbage collection.
    if args.manual_gc:
@@ -1401,25 +1408,34 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
    prof = None
    if args.profile and torch.distributed.get_rank() in args.profile_ranks and args.use_pytorch_profiler:
+        def trace_handler(p):
+            from pathlib import Path
+            Path(f"{args.profile_dir}").mkdir(parents=True, exist_ok=True)
+            print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
+            p.export_chrome_trace("{path}/trace_rank{rank}_step{step}.json".format(
+                path=args.profile_dir, rank=torch.distributed.get_rank(), step=p.step_num))
        prof = torch.profiler.profile(
+            activities=[
+                torch.profiler.ProfilerActivity.CPU,
+                torch.profiler.ProfilerActivity.CUDA,
+            ],
            schedule=torch.profiler.schedule(
                wait=max(args.profile_step_start-1, 0),
                warmup=1 if args.profile_step_start > 0 else 0,
                active=args.profile_step_end-args.profile_step_start,
                repeat=1),
-            on_trace_ready=torch.profiler.tensorboard_trace_handler(args.tensorboard_dir),
-            record_shapes=True,
-            with_stack=True)
+            # record_shapes=True,
+            # with_stack=True,
+            on_trace_ready=trace_handler,)
        prof.start()
    # Run training iterations till done.
    while iteration < args.train_iters:
        if args.profile and torch.distributed.get_rank() in args.profile_ranks:
            if args.use_pytorch_profiler:
                prof.step()
-            elif iteration == args.profile_step_start:
-                torch.cuda.cudart().cudaProfilerStart()
-                torch.autograd.profiler.emit_nvtx(record_shapes=True).__enter__()
        maybe_finalize_async_save(blocking=False)
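
This hunk is the heart of the commit: the profiler drops tensorboard_trace_handler (which wrote into the TensorBoard directory) in favour of a custom trace_handler that prints a kernel-time summary and exports a per-rank Chrome trace into --profile-dir, and the non-PyTorch path (cudaProfilerStart plus emit_nvtx) is removed. A self-contained editorial sketch of the same schedule-plus-handler pattern, runnable in a single process with a dummy workload (profile_dir and the step bounds are placeholders mirroring PROFILE_ARGS, not Megatron code):

# Editorial sketch: torch.profiler with a step schedule and a custom handler.
from pathlib import Path
import torch

profile_dir = "prof_data"                      # placeholder, mirrors --profile-dir
profile_step_start, profile_step_end = 4, 5    # placeholders, mirror PROFILE_ARGS

def trace_handler(p):
    Path(profile_dir).mkdir(parents=True, exist_ok=True)
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=20))
    p.export_chrome_trace(f"{profile_dir}/trace_step{p.step_num}.json")

activities = [torch.profiler.ProfilerActivity.CPU]
if torch.cuda.is_available():
    activities.append(torch.profiler.ProfilerActivity.CUDA)

prof = torch.profiler.profile(
    activities=activities,
    schedule=torch.profiler.schedule(
        wait=max(profile_step_start - 1, 0),   # idle steps before warmup
        warmup=1 if profile_step_start > 0 else 0,
        active=profile_step_end - profile_step_start,
        repeat=1),
    on_trace_ready=trace_handler)

prof.start()
x = torch.randn(512, 512)
for step in range(10):           # stands in for the training loop
    for _ in range(5):
        torch.mm(x, x)           # dummy work per "iteration"
    prof.step()                  # advances the wait/warmup/active schedule
prof.stop()

When the active window closes (after the fifth prof.step() call here), trace_handler fires once; the exported JSON can be opened in chrome://tracing or Perfetto.
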
@@ -1431,12 +1447,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        if get_num_microbatches() != num_microbatches and iteration != 0:
            assert get_num_microbatches() > num_microbatches, \
                (f"Number of microbatches should be increasing due to batch size rampup; "
                 f"instead going from {num_microbatches} to {get_num_microbatches()}")
            if args.save is not None:
                save_checkpoint_and_time(iteration, model, optimizer,
                                         opt_param_scheduler,
                                         num_floating_point_operations_so_far,
                                         checkpointing_context, train_data_iterator=train_data_iterator)
        num_microbatches = get_num_microbatches()
        update_num_microbatches(args.consumed_train_samples, consistency_check=True, verbose=True)
@@ -1444,23 +1460,23 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        args.curr_iteration = iteration
        loss_dict, skipped_iter, should_checkpoint, should_exit, exit_code, grad_norm, num_zeros_in_grad = \
            train_step(forward_step_func,
                       train_data_iterator,
                       model,
                       optimizer,
                       opt_param_scheduler,
                       config)
        if should_checkpoint:
            save_checkpoint_and_time(iteration, model, optimizer,
                                     opt_param_scheduler,
                                     num_floating_point_operations_so_far,
                                     checkpointing_context, train_data_iterator=train_data_iterator)
        if should_exit:
            break
        # why is skipped_iter ignored?
        iteration += 1
        batch_size = mpu.get_data_parallel_world_size() * \
                     args.micro_batch_size * \
                     get_num_microbatches()
        args.consumed_train_samples += batch_size
        num_skipped_samples_in_batch = (get_current_global_batch_size() -
                                        get_current_running_global_batch_size())
@@ -1486,11 +1502,11 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            else:
                learning_rate = param_group['lr']
        report_memory_flag = training_log(loss_dict, total_loss_dict,
                                          learning_rate,
                                          decoupled_learning_rate,
                                          iteration, loss_scale,
                                          report_memory_flag, skipped_iter,
                                          grad_norm, params_norm, num_zeros_in_grad)
        # Evaluation.
        if args.eval_interval and iteration % args.eval_interval == 0 and \
@@ -1504,10 +1520,10 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
            prefix = f'iteration {iteration}'
            timers('eval-time', log_level=0).start(barrier=True)
            evaluate_and_print_results(prefix, forward_step_func,
                                       valid_data_iterator, model,
                                       iteration, process_non_loss_data_func,
                                       config, verbose=False, write_to_tensorboard=True,
                                       non_loss_data_func=non_loss_data_func)
            eval_duration += timers('eval-time').elapsed()
            eval_iterations += args.eval_iters
            timers('eval-time').stop()
@@ -1527,12 +1543,12 @@ def train(forward_step_func, model, optimizer, opt_param_scheduler,
        # Miscellaneous post-training-step functions (e.g., FT heartbeats, GC).
        # Some of these only happen at specific iterations.
        post_training_step_callbacks(model, optimizer, opt_param_scheduler, iteration, prof,
                                     num_floating_point_operations_since_last_log_event)
        # Checkpoint and decide whether to exit.
        should_exit = checkpoint_and_decide_exit(model, optimizer, opt_param_scheduler, iteration,
                                                 num_floating_point_operations_so_far,
                                                 checkpointing_context, train_data_iterator)
        if should_exit:
            break
...