"ssh:/git@developer.sourcefind.cn:2222/OpenDAS/vllm_cscc.git" did not exist on "7ef40bb9832f4a8fca9f9924a35ae77a69ee7076"
Unverified commit 951fdd66, authored by Woosuk Kwon, committed by GitHub
Browse files

[TPU] Set per-rank XLA cache (#7533)

parent 2ecf7b17
@@ -102,12 +102,12 @@ class TPUWorker(LoraNotSupportedWorkerBase, LocalOrDistributedWorkerBase):
         # 30-40 graphs for decode. 128 is an arbitrary safe number.
         torch._dynamo.config.cache_size_limit = 128
         # Use persistent cache to avoid XLA recompilation.
-        # NOTE(woosuk): This does not completely eliminate the recompilation
-        # overhead because dynamo does not cache the compiled results.
-        # NOTE(woosuk): Set readonly=False only for the rank 0 process to avoid
-        # race conditions.
-        xr.initialize_cache(envs.VLLM_XLA_CACHE_PATH,
-                            readonly=not self.is_driver_worker)
+        # NOTE(woosuk): Set per-rank cache path since different ranks
+        # can have slightly different XLA graphs.
+        world_size = self.parallel_config.world_size
+        per_rank_path = os.path.join(envs.VLLM_XLA_CACHE_PATH,
+                                     f"tp{world_size}_rank{self.rank}")
+        xr.initialize_cache(per_rank_path, readonly=False)
 
     def load_model(self):
         self.model_runner.load_model()
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment