Commit af277ff1 authored by one

Enhance prompt handling and profiling in evo2 scripts

- Update run.sh to include new command-line options for prompt stretching and token limits.
- Modify test_evo2_generation_batched.py to adjust profiling settings and improve output formatting.
- Add support for stretching prompts to the longest length for batch processing.
parent c647fd9a
```diff
--- a/run.sh
+++ b/run.sh
@@ -13,7 +13,9 @@ EVO_CMD="numactl -m 1 -N 1 \
 python -m evo2.test.test_evo2_generation_batched \
 --model_name ${MODEL_NAME} \
 --local_path ${MODEL_PATH} \
---trace_gzip true \
+--n_tokens 500 \
+--prompt_stretch \
+--trace_gzip \
 --trace_logdir ./log/pt-trace/"
 
 run_all_tests() {
```
```diff
--- a/evo2/test/test_evo2_generation_batched.py
+++ b/evo2/test/test_evo2_generation_batched.py
@@ -90,7 +90,7 @@ def generate_and_score(
         torch.cuda.synchronize()
         step_time += time.perf_counter()
         print(
-            f"[{i}:{min(i + batch_size, len(prompts))}) E2E Time for model.generate: {step_time:.3f} s"
+            f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate: {step_time:.3f} s"
         )
         for j, decoded_seq in enumerate(generated.sequences):
```
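The print fix above changes the batch label from the half-open `[i:j)` style to an inclusive `[i:j-1]`, so the displayed indices match the elements actually generated. The `step_time += time.perf_counter()` line also reveals the timing idiom in use: `step_time` is seeded with a negated `time.perf_counter()` before `model.generate`, so adding the counter afterwards yields the elapsed wall time. A minimal sketch of that pattern (the `model.generate` call here is a placeholder, not the script's exact signature):

```python
import time

import torch

def timed_generate(model, batch_prompts):
    # Seed with the negated start timestamp; adding perf_counter() later
    # leaves end - start in step_time.
    step_time = -time.perf_counter()
    generated = model.generate(batch_prompts)  # placeholder call signature
    torch.cuda.synchronize()  # flush queued GPU work before stopping the clock
    step_time += time.perf_counter()
    return generated, step_time
```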
```diff
@@ -182,14 +182,15 @@ def generate_and_score_prof(
     print("\n[TRACE] Start profiling...")
+    # Enable profiler features as needed
     with torch.profiler.profile(
         schedule=torch.profiler.schedule(wait=0, warmup=trace_step, active=1, repeat=1),
         on_trace_ready=custom_trace_handler(dir_name=trace_logdir, use_gzip=trace_gzip),
-        record_shapes=True,
-        profile_memory=False,  # enable as needed
-        with_stack=False,  # enable as needed
-        with_flops=True,
-        with_modules=True,
+        record_shapes=False,
+        profile_memory=False,
+        with_stack=False,
+        with_flops=False,
+        with_modules=False,
     ) as prof:
         for i in range(0, len(prompts), batch_size):
             batch_prompts = prompts[i : i + batch_size]
```
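With `schedule(wait=0, warmup=trace_step, active=1, repeat=1)` the profiler idles through `trace_step` warm-up iterations and records exactly one, so turning off `record_shapes`, `with_flops`, and `with_modules` keeps that single captured step cheap. `custom_trace_handler` itself is not shown in this diff; its signature suggests a wrapper similar to `torch.profiler.tensorboard_trace_handler(dir_name, worker_name=None, use_gzip=False)`. A hypothetical reconstruction, assuming it simply exports one (optionally gzipped) Chrome trace per profiling cycle:

```python
import os
import socket
import time

import torch

def custom_trace_handler(dir_name: str, use_gzip: bool = False):
    # Hypothetical sketch; the real custom_trace_handler is defined elsewhere.
    os.makedirs(dir_name, exist_ok=True)

    def handler(prof: torch.profiler.profile) -> None:
        # Illustrative file naming; a ".gz" suffix makes export_chrome_trace
        # compress the trace in recent PyTorch versions.
        fname = f"{socket.gethostname()}_{os.getpid()}.{int(time.time() * 1000)}.pt.trace.json"
        if use_gzip:
            fname += ".gz"
        prof.export_chrome_trace(os.path.join(dir_name, fname))

    return handler
```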
```diff
@@ -208,7 +209,7 @@ def generate_and_score_prof(
             torch.cuda.synchronize()
             step_time += time.perf_counter()
             print(
-                f"[{i}:{min(i + batch_size, len(prompts))}) E2E Time for model.generate: {step_time:.3f} s"
+                f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate: {step_time:.3f} s"
             )
             for j, decoded_seq in enumerate(generated.sequences):
@@ -247,6 +248,11 @@ def main():
     parser.add_argument(
         "--batch_size", type=int, default=1, help="Batch size for generation"
     )
+    parser.add_argument(
+        "--prompt_stretch",
+        action="store_true",
+        help="Stretch all prompts to the longest prompt length",
+    )
     parser.add_argument(
         "--trace",
         action="store_true",
@@ -266,9 +272,8 @@ def main():
     )
     parser.add_argument(
         "--trace_gzip",
-        type=bool,
-        default=False,
-        help="Gzip torch profiler trace output (default: False)",
+        action="store_true",
+        help="Gzip torch profiler trace output",
     )
 
     args = parser.parse_args()
```
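The `--trace_gzip` rewrite fixes a classic argparse pitfall: `type=bool` applies `bool()` to the raw command-line string, and every non-empty string, including "false", is truthy, so the old flag could never be switched off from the command line. `action="store_true"` is the idiomatic fix, which is also why run.sh above now passes a bare `--trace_gzip` instead of `--trace_gzip true`. A self-contained demonstration (the `--broken`/`--fixed` names are illustrative):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--broken", type=bool, default=False)  # the old pattern
parser.add_argument("--fixed", action="store_true")        # the new pattern

# type=bool calls bool() on the raw string: any non-empty string is truthy.
assert parser.parse_args(["--broken", "false"]).broken is True
assert parser.parse_args(["--broken", ""]).broken is False  # only "" is falsy

# store_true takes no value: present -> True, absent -> False.
assert parser.parse_args(["--fixed"]).fixed is True
assert parser.parse_args([]).fixed is False
```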
```diff
@@ -291,12 +296,14 @@ def main():
     # Read and process sequences
     sequences = read_prompts("prompts.csv")
+    print("[DEBUG] Prompt lengths:", [len(seq) for seq in sequences])
 
     # Debugging: replace all prompts with the longest prompt
-    if args.batch_size > 1:
+    if args.prompt_stretch:
         longest_prompt = max(sequences, key=len)
         sequences = [longest_prompt] * len(sequences)
         print(
-            f"[DEBUG] Using longest prompt len={len(longest_prompt)} for all sequences"
+            f"[DEBUG] Using the longest prompt with len={len(longest_prompt)} for all sequences"
         )
 
     if args.trace:
```
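Note the guard change: prompt stretching was previously tied, implicitly, to `args.batch_size > 1`; it is now an explicit `--prompt_stretch` opt-in. Making every batch element the same length removes per-prompt length variance, so the per-batch `model.generate` timings become directly comparable. A small sketch of the effect, using hypothetical prompts standing in for what `read_prompts("prompts.csv")` would return:

```python
def stretch_prompts(sequences: list[str]) -> list[str]:
    """Replace every prompt with the longest one, as --prompt_stretch does."""
    longest_prompt = max(sequences, key=len)
    return [longest_prompt] * len(sequences)

# Hypothetical prompts standing in for the contents of prompts.csv:
sequences = ["ACGT", "ACGTACGTACGT", "AC"]
print([len(s) for s in sequences])                   # [4, 12, 2]  (ragged)
print([len(s) for s in stretch_prompts(sequences)])  # [12, 12, 12] (uniform)
```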