Commit c647fd9a authored by one
Browse files

Enhance profiling capabilities in evo2 scripts

- Update run.sh to include trace logging options with gzip support.
- Modify test_evo2_generation_batched.py to add command-line arguments for trace log directory and gzip option.
- Refactor custom trace handler to utilize gzip compression for trace outputs.
parent 3bb2e7a5
......@@ -12,7 +12,9 @@ export MODEL_PATH=/models/arcinstitute/evo2_7b/evo2_7b.pt
EVO_CMD="numactl -m 1 -N 1 \
python -m evo2.test.test_evo2_generation_batched \
--model_name ${MODEL_NAME} \
--local_path ${MODEL_PATH}"
--local_path ${MODEL_PATH} \
--trace_gzip true \
--trace_logdir ./log/pt-trace/"
run_all_tests() {
local batch_size=$1
......
......@@ -109,11 +109,12 @@ def generate_and_score(
def custom_trace_handler(
dir_name="./log/pt-trace/",
worker_name=None,
use_gzip=False,
sort_by="self_device_time_total",
top_n=20,
):
tb_handler = torch.profiler.tensorboard_trace_handler(
dir_name=dir_name, worker_name=worker_name
dir_name=dir_name, worker_name=worker_name, use_gzip=use_gzip
)
field_fallbacks = {
......@@ -162,6 +163,8 @@ def generate_and_score_prof(
top_p=1.0,
batch_size=2,
trace_step=1,
trace_logdir="./log/pt-trace/",
trace_gzip=False,
):
"""Prompt with first half, generate and score on 2nd half with torch profiler.
......@@ -181,7 +184,7 @@ def generate_and_score_prof(
with torch.profiler.profile(
schedule=torch.profiler.schedule(wait=0, warmup=trace_step, active=1, repeat=1),
on_trace_ready=custom_trace_handler(dir_name="./log/pt-trace/"),
on_trace_ready=custom_trace_handler(dir_name=trace_logdir, use_gzip=trace_gzip),
record_shapes=True,
profile_memory=False, # 按需开启
with_stack=False, # 按需开启
......@@ -255,6 +258,18 @@ def main():
default=1,
help="Attach torch profiler to specific step (default: 1)",
)
parser.add_argument(
"--trace_logdir",
type=str,
default="./log/pt-trace/",
help="Directory for torch profiler trace output (default: ./log/pt-trace/)",
)
# NOTE: argparse's type=bool calls bool() on the raw string, so every
# non-empty value -- including "false" and "0" -- would enable gzip.
# Parse the text explicitly instead: the usual truthy spellings enable
# gzip (so `--trace_gzip true` keeps working), anything else disables it.
parser.add_argument(
    "--trace_gzip",
    type=lambda value: value.strip().lower() in ("1", "true", "t", "yes", "y", "on"),
    default=False,
    help="Gzip torch profiler trace output (default: False)",
)
args = parser.parse_args()
......@@ -290,6 +305,8 @@ def main():
sequences=sequences,
model=model,
trace_step=args.trace_step,
trace_logdir=args.trace_logdir,
trace_gzip=args.trace_gzip,
**test_params,
)
else:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment