"vscode:/vscode.git/clone" did not exist on "4026ae31e910d50da2b80c1c386f1d1db7f1b7d8"
Unverified Commit beebf474 authored by Siyuan Liu's avatar Siyuan Liu Committed by GitHub
Browse files

[TPU][Profiler] Support start_profile/stop_profile in TPU worker (#13988)


Signed-off-by: default avatarSiyuan Liu <lsiyuan@google.com>
Co-authored-by: default avatarmgoin <mgoin64@gmail.com>
parent f89978ad
......@@ -17,8 +17,9 @@ ray[default]
--find-links https://storage.googleapis.com/libtpu-releases/index.html
--find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html
--find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html
torch==2.7.0.dev20250226+cpu
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250226+cxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp39-cp39-linux_x86_64.whl ; python_version == "3.9"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp310-cp310-linux_x86_64.whl ; python_version == "3.10"
torch_xla[tpu, pallas] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.7.0.dev20250227%2Bcxx11-cp311-cp311-linux_x86_64.whl ; python_version == "3.11"
......@@ -7,6 +7,7 @@ import torch
import torch.distributed
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
......@@ -65,6 +66,15 @@ class TPUWorker:
from vllm.utils import init_cached_hf_modules
init_cached_hf_modules()
self.profiler = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)
self.profiler = xp.start_server(9012)
def init_device(self):
os.environ["PJRT_DEVICE"] = "TPU"
torch.set_grad_enabled(False)
......@@ -152,6 +162,15 @@ class TPUWorker:
output = self.model_runner.execute_model(scheduler_output)
return output if self.is_driver_worker else None
def profile(self, is_start: bool = True):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
if is_start:
xp.start_trace(self.profile_dir)
else:
xp.stop_trace()
def load_model(self) -> None:
self.model_runner.load_model()
......
......@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
import torch
import torch_xla.core.xla_model as xm
import torch_xla.debug.profiler as xp
import torch_xla.runtime as xr
import vllm.envs as envs
......@@ -93,6 +94,27 @@ class TPUWorker(LoRANotSupportedWorkerBase, LocalOrDistributedWorkerBase):
f"tp{world_size}_rank{rank}")
xr.initialize_cache(per_rank_path, readonly=False)
self.profiler = None
if envs.VLLM_TORCH_PROFILER_DIR and self.rank < 1:
# For TPU, we can only have 1 active profiler session for 1 profiler
# server. So we only profile on rank0.
self.profile_dir = envs.VLLM_TORCH_PROFILER_DIR
logger.info("Profiling enabled. Traces will be saved to: %s",
self.profile_dir)
self.profiler = xp.start_server(9012)
def start_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.start_trace(self.profile_dir)
def stop_profile(self):
if self.rank < 1:
if self.profiler is None:
raise RuntimeError("Profiler is not enabled.")
xp.stop_trace()
def load_model(self):
self.model_runner.load_model()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment