Unverified commit ac19b519, authored by youkaichao and committed by GitHub
Browse files

[core] fix sleep mode in pytorch 2.6 (#13456)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent a1074b3e
......@@ -9,7 +9,7 @@
# the only successful approach is to call cuda driver API in C.
import dataclasses
from contextlib import contextmanager
from typing import Callable, Dict, Optional, Tuple, Union
from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch
......@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool):
yield mem_pool
yield mem_pool, new_alloc
class CuMemAllocator:
......@@ -142,6 +142,7 @@ class CuMemAllocator:
def __init__(self):
self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None:
"""
......@@ -231,7 +232,13 @@ class CuMemAllocator:
old_tag = self.current_tag
self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback):
self.python_free_callback) as data:
# With PyTorch 2.6 we started to hit another PyTorch bug here,
# possibly caused by a gc-related issue involving the allocator and
# the memory pool.
# To avoid the issue, we keep a reference to the yielded data
# (the memory pool and the allocator) in self.allocator_and_pools.
# See https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment