Commit af277ff1 authored by one
Browse files

Enhance prompt handling and profiling in evo2 scripts

- Update run.sh to include new command-line options for prompt stretching and token limits.
- Modify test_evo2_generation_batched.py to adjust profiling settings and improve output formatting.
- Add support for stretching prompts to the longest length for batch processing.
parent c647fd9a
......@@ -13,7 +13,9 @@ EVO_CMD="numactl -m 1 -N 1 \
python -m evo2.test.test_evo2_generation_batched \
--model_name ${MODEL_NAME} \
--local_path ${MODEL_PATH} \
--trace_gzip true \
--n_tokens 500 \
--prompt_stretch \
--trace_gzip \
--trace_logdir ./log/pt-trace/"
run_all_tests() {
......
......@@ -90,7 +90,7 @@ def generate_and_score(
torch.cuda.synchronize()
step_time += time.perf_counter()
print(
f"[{i}:{min(i + batch_size, len(prompts))}) E2E Time for model.generate: {step_time:.3f} s"
f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate: {step_time:.3f} s"
)
for j, decoded_seq in enumerate(generated.sequences):
......@@ -182,14 +182,15 @@ def generate_and_score_prof(
print("\n[TRACE] Start profiling...")
# 按需开启功能
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=0, warmup=trace_step, active=1, repeat=1),
on_trace_ready=custom_trace_handler(dir_name=trace_logdir, use_gzip=trace_gzip),
record_shapes=True,
profile_memory=False, # 按需开启
with_stack=False, # 按需开启
with_flops=True,
with_modules=True,
record_shapes=False,
profile_memory=False,
with_stack=False,
with_flops=False,
with_modules=False,
) as prof:
for i in range(0, len(prompts), batch_size):
batch_prompts = prompts[i : i + batch_size]
......@@ -208,7 +209,7 @@ def generate_and_score_prof(
torch.cuda.synchronize()
step_time += time.perf_counter()
print(
f"[{i}:{min(i + batch_size, len(prompts))}) E2E Time for model.generate: {step_time:.3f} s"
f"[{i}:{min(i + batch_size, len(prompts)) - 1}] E2E Time for model.generate: {step_time:.3f} s"
)
for j, decoded_seq in enumerate(generated.sequences):
......@@ -247,6 +248,11 @@ def main():
parser.add_argument(
"--batch_size", type=int, default=1, help="Batch size for generation"
)
parser.add_argument(
"--prompt_stretch",
action="store_true",
help="Stretch all prompts to the longest prompt length",
)
parser.add_argument(
"--trace",
action="store_true",
......@@ -266,9 +272,8 @@ def main():
)
parser.add_argument(
"--trace_gzip",
type=bool,
default=False,
help="Gzip torch profiler trace output (default: False)",
action="store_true",
help="Gzip torch profiler trace output",
)
args = parser.parse_args()
......@@ -291,12 +296,14 @@ def main():
# Read and process sequences
sequences = read_prompts("prompts.csv")
print("[DEBUG] Prompt lengths:", [len(seq) for seq in sequences])
# Debugging: replace all prompts with the longest prompt
if args.batch_size > 1:
if args.prompt_stretch:
longest_prompt = max(sequences, key=len)
sequences = [longest_prompt] * len(sequences)
print(
f"[DEBUG] Using longest prompt len={len(longest_prompt)} for all sequences"
f"[DEBUG] Using the longest prompt with len={len(longest_prompt)} for all sequences"
)
if args.trace:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment