Unverified Commit 038beded authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[TPU] [Perf] Improve Memory Usage Estimation (#15671)


Signed-off-by: default avatarRobert Shaw <robshaw@redhat.com>
Co-authored-by: default avatarRobert Shaw <robshaw@redhat.com>
parent d03308be
......@@ -161,7 +161,13 @@ class TPUWorker:
# intermediate activations.
m = xm.get_memory_info(self.device)
total_memory_size = m["bytes_limit"]
profiled = m["peak_bytes_used"] # Weights + intermediate activations.
current_mem = m["bytes_used"]
# Ideally we would use profiled = m["peak_bytes_used"] to
# get weights + activations. But there is memory used during
# compilation / weight loading that impacts the peak and
# there is no way to reset peak memory in XLA, So we
# use the heuristic of 2% of weights.
profiled = current_mem * 1.02
# Calculate the TPU KV cache size based on profiling.
usable_memory_size = int(total_memory_size *
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment