"git@developer.sourcefind.cn:tsoc/superbenchmark.git" did not exist on "6076251816307fb00c9a65e368335d680451231c"
Commit 23db469a authored by one's avatar one
Browse files

Update prompt_stretch for evo2

- Remove prompt_stretch option from run.sh
- Adjust condition in test_evo2_generation_batched.py to allow prompt stretching based on batch size
parent af277ff1
...@@ -14,7 +14,6 @@ EVO_CMD="numactl -m 1 -N 1 \ ...@@ -14,7 +14,6 @@ EVO_CMD="numactl -m 1 -N 1 \
--model_name ${MODEL_NAME} \ --model_name ${MODEL_NAME} \
--local_path ${MODEL_PATH} \ --local_path ${MODEL_PATH} \
--n_tokens 500 \ --n_tokens 500 \
--prompt_stretch \
--trace_gzip \ --trace_gzip \
--trace_logdir ./log/pt-trace/" --trace_logdir ./log/pt-trace/"
......
...@@ -299,7 +299,7 @@ def main(): ...@@ -299,7 +299,7 @@ def main():
print("[DEBUG] Prompt lengths:", [len(seq) for seq in sequences]) print("[DEBUG] Prompt lengths:", [len(seq) for seq in sequences])
# Debugging: replace all prompts with the longest prompt # Debugging: replace all prompts with the longest prompt
if args.prompt_stretch: if args.prompt_stretch or args.batch_size > 1:
longest_prompt = max(sequences, key=len) longest_prompt = max(sequences, key=len)
sequences = [longest_prompt] * len(sequences) sequences = [longest_prompt] * len(sequences)
print( print(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment