Commit 8fda607c authored by yuguo's avatar yuguo
Browse files

[DCU] Fix workspace (WS) leak when UB is initialized and destroyed more than once

parent 9da3621b
......@@ -225,9 +225,15 @@ def initialize_ub(
flush=True,
)
# Increase the workspace by the number of maximum concurrent streams
# Allocate cuBLAS workspace with expanded size for chunking in overlapping GEMM calls
global _cublas_workspace
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
if _cublas_workspace is None:
_cublas_workspace = get_workspace().repeat(_NUM_MAX_UB_STREAMS)
elif _cublas_workspace.numel() != get_cublas_workspace_size_bytes() * _NUM_MAX_UB_STREAMS:
# This ensures we don't do `.repeat()` on an already expanded workspace
_cublas_workspace = torch.empty(
get_cublas_workspace_size_bytes(), dtype=torch.uint8, device="cuda"
).repeat(_NUM_MAX_UB_STREAMS)
# Default buffer precision: AllGather buffers use fp8 when using fp8 recipe
layers_all_gather_overlap = [
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment