Commit 72ec54e3 authored by one
Browse files

Enhance profiling and warmup functionality in evo2 scripts

- Update run.sh to include new options for warmups and prompt stretching.
- Refactor test_evo2_generation_batched.py to improve trace output formatting and add support for warmup sequences.
- Adjust batch processing to include detailed profiling for each step.
parent 23db469a
...@@ -14,25 +14,25 @@ EVO_CMD="numactl -m 1 -N 1 \ ...@@ -14,25 +14,25 @@ EVO_CMD="numactl -m 1 -N 1 \
--model_name ${MODEL_NAME} \ --model_name ${MODEL_NAME} \
--local_path ${MODEL_PATH} \ --local_path ${MODEL_PATH} \
--n_tokens 500 \ --n_tokens 500 \
--n_warmups 1 \
--prompt_stretch \
--trace_gzip \ --trace_gzip \
--trace_logdir ./log/pt-trace/" --trace_logdir ./log/pt-trace/stretch"
run_all_tests() {
local batch_size=$1
for batch_size in 1 2; do
echo "================================================" echo "================================================"
echo "Running all tests for batch size ${batch_size}" echo "Running all tests for batch size ${batch_size}"
echo "================================================" echo "================================================"
mkdir -p log &> /dev/null mkdir -p log &>/dev/null
echo "==== Normal run ====" echo "==== Normal run ===="
${EVO_CMD} --batch_size ${batch_size} ${EVO_CMD} --batch_size ${batch_size}
# echo "==== Torch profiler trace for step 0 ====" for step in 0 1; do
# ${EVO_CMD} --batch_size ${batch_size} --trace --trace_step 0 echo "==== Torch profiler trace for prompt ${step} ===="
${EVO_CMD} --batch_size ${batch_size} --trace --trace_step ${step} \
# echo "==== Torch profiler trace for step 1 ====" --trace_file_prefix evo2-bw1000-bs${batch_size}-s${step}
# ${EVO_CMD} --batch_size ${batch_size} --trace --trace_step 1 done
# echo "==== Hipprof trace ====" # echo "==== Hipprof trace ===="
# hipprof --hip-trace -o log/trace-bs${batch_size} \ # hipprof --hip-trace -o log/trace-bs${batch_size} \
...@@ -43,7 +43,4 @@ run_all_tests() { ...@@ -43,7 +43,4 @@ run_all_tests() {
# --stats=true --trace=cuda \ # --stats=true --trace=cuda \
# -o log/trace-bs${batch_size} \ # -o log/trace-bs${batch_size} \
# ${EVO_CMD} --batch_size ${batch_size} # ${EVO_CMD} --batch_size ${batch_size}
} done
run_all_tests 1
run_all_tests 2
...@@ -90,7 +90,7 @@ def generate_and_score( ...@@ -90,7 +90,7 @@ def generate_and_score(
torch.cuda.synchronize() torch.cuda.synchronize()
step_time += time.perf_counter() step_time += time.perf_counter()
print( print(
f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate: {step_time:.3f} s" f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate (batch_size={batch_size}): {step_time:.3f} s"
) )
for j, decoded_seq in enumerate(generated.sequences): for j, decoded_seq in enumerate(generated.sequences):
...@@ -106,52 +106,6 @@ def generate_and_score( ...@@ -106,52 +106,6 @@ def generate_and_score(
return reshaped_scores return reshaped_scores
def custom_trace_handler(
    dir_name="./log/pt-trace/",
    worker_name=None,
    use_gzip=False,
    sort_by="self_device_time_total",
    top_n=20,
):
    """Build an ``on_trace_ready`` callback that saves a trace and prints a summary.

    The returned callback first delegates to
    ``torch.profiler.tensorboard_trace_handler`` (built from ``dir_name``,
    ``worker_name`` and ``use_gzip``) and then prints
    ``prof.key_averages().table(...)`` limited to ``top_n`` rows.

    Args:
        dir_name: Directory forwarded to the TensorBoard trace handler.
        worker_name: Worker name forwarded to the TensorBoard trace handler.
        use_gzip: Gzip flag forwarded to the TensorBoard trace handler.
        sort_by: Attribute name used to sort the summary table. If the
            averaged events do not expose it, the fallback chain in
            ``field_fallbacks`` is walked until an available attribute is
            found. ``None`` requests the table's default order.
        top_n: Row limit passed to the summary table.

    Returns:
        A callable taking the profiler object, suitable for ``on_trace_ready``.
    """
    tb_handler = torch.profiler.tensorboard_trace_handler(
        dir_name=dir_name, worker_name=worker_name, use_gzip=use_gzip
    )

    # Each key maps to the next attribute name to try when it is missing.
    field_fallbacks = {
        "self_device_time_total": "self_cuda_time_total",
        "device_time_total": "cuda_time_total",
        "self_cuda_time_total": "self_cpu_time_total",
    }

    def resolve_sort_key(sample_event):
        """Return the first usable sort key for ``sample_event``, or None."""
        if sort_by is None:
            # No explicit key requested: keep the table's default order.
            return None
        if hasattr(sample_event, sort_by):
            return sort_by

        # Walk the whole fallback chain instead of only the first hop, so
        # e.g. self_device_time_total can end up at self_cpu_time_total.
        visited = {sort_by}
        candidate = field_fallbacks.get(sort_by)
        while candidate is not None and candidate not in visited:
            if hasattr(sample_event, candidate):
                print(
                    f"[PROFILER] '{sort_by}' not found. Falling back to '{candidate}'."
                )
                return candidate
            visited.add(candidate)  # guards against a cyclic fallback table
            candidate = field_fallbacks.get(candidate)

        print(f"[PROFILER] Sort key '{sort_by}' invalid. Using default order.")
        return None

    def handler(prof):
        # Persist the trace first so it is written even if the summary fails.
        tb_handler(prof)

        avgs = prof.key_averages()

        final_sort_key = sort_by
        if len(avgs) > 0:
            # Probe one averaged event to learn which attributes exist.
            final_sort_key = resolve_sort_key(avgs[0])

        print(avgs.table(sort_by=final_sort_key, row_limit=top_n))

    return handler
def generate_and_score_prof( def generate_and_score_prof(
*, *,
sequences, sequences,
...@@ -165,6 +119,7 @@ def generate_and_score_prof( ...@@ -165,6 +119,7 @@ def generate_and_score_prof(
trace_step=1, trace_step=1,
trace_logdir="./log/pt-trace/", trace_logdir="./log/pt-trace/",
trace_gzip=False, trace_gzip=False,
trace_file_prefix=None,
): ):
"""Prompt with first half, generate and score on 2nd half with torch profiler. """Prompt with first half, generate and score on 2nd half with torch profiler.
...@@ -185,7 +140,9 @@ def generate_and_score_prof( ...@@ -185,7 +140,9 @@ def generate_and_score_prof(
# 按需开启功能 # 按需开启功能
with torch.profiler.profile( with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=0, warmup=trace_step, active=1, repeat=1), schedule=torch.profiler.schedule(wait=0, warmup=trace_step, active=1, repeat=1),
on_trace_ready=custom_trace_handler(dir_name=trace_logdir, use_gzip=trace_gzip), on_trace_ready=torch.profiler.tensorboard_trace_handler(
dir_name=trace_logdir, worker_name=trace_file_prefix, use_gzip=trace_gzip
),
record_shapes=False, record_shapes=False,
profile_memory=False, profile_memory=False,
with_stack=False, with_stack=False,
...@@ -209,7 +166,7 @@ def generate_and_score_prof( ...@@ -209,7 +166,7 @@ def generate_and_score_prof(
torch.cuda.synchronize() torch.cuda.synchronize()
step_time += time.perf_counter() step_time += time.perf_counter()
print( print(
f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate: {step_time:.3f} s" f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate (batch_size={batch_size}): {step_time:.3f} s"
) )
for j, decoded_seq in enumerate(generated.sequences): for j, decoded_seq in enumerate(generated.sequences):
...@@ -253,6 +210,12 @@ def main(): ...@@ -253,6 +210,12 @@ def main():
action="store_true", action="store_true",
help="Stretch all prompts to the longest prompt length", help="Stretch all prompts to the longest prompt length",
) )
parser.add_argument(
"--n_warmups",
type=int,
default=0,
help="Number of warmups to run",
)
parser.add_argument( parser.add_argument(
"--trace", "--trace",
action="store_true", action="store_true",
...@@ -275,6 +238,12 @@ def main(): ...@@ -275,6 +238,12 @@ def main():
action="store_true", action="store_true",
help="Gzip torch profiler trace output", help="Gzip torch profiler trace output",
) )
parser.add_argument(
"--trace_file_prefix",
type=str,
default=None,
help="Prefix for torch profiler trace output file",
)
args = parser.parse_args() args = parser.parse_args()
...@@ -300,20 +269,28 @@ def main(): ...@@ -300,20 +269,28 @@ def main():
# Debugging: replace all prompts with the longest prompt # Debugging: replace all prompts with the longest prompt
if args.prompt_stretch or args.batch_size > 1: if args.prompt_stretch or args.batch_size > 1:
longest_prompt = max(sequences, key=len) uniform_prompt = sequences[1] # length=7056
sequences = [longest_prompt] * len(sequences) sequences = [uniform_prompt] * len(sequences)
print( print(
f"[DEBUG] Using the longest prompt with len={len(longest_prompt)} for all sequences" f"[DEBUG] Using the uniform prompt with length {len(uniform_prompt)} for all sequences"
) )
# Warmup
if args.n_warmups > 0:
warmup_sequences = sequences[:1] * args.n_warmups
warmup_params = {**test_params, "n_tokens": 16}
generate_and_score(sequences=warmup_sequences, model=model, **warmup_params)
print(f"[DEBUG] Running {args.n_warmups} warmups with the first prompt")
if args.trace: if args.trace:
print("[TRACE] Using generate_and_score_prof with torch profiler") print("[TRACE] Using generate_and_score_prof with torch profiler")
scores = generate_and_score_prof( scores = generate_and_score_prof(
sequences=sequences, sequences=sequences,
model=model, model=model,
trace_step=args.trace_step, trace_step=args.trace_step,
trace_logdir=args.trace_logdir,
trace_gzip=args.trace_gzip, trace_gzip=args.trace_gzip,
trace_logdir=args.trace_logdir,
trace_file_prefix=args.trace_file_prefix,
**test_params, **test_params,
) )
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment