Unverified commit ac19b519, authored by youkaichao and committed by GitHub
Browse files

[core] fix sleep mode in pytorch 2.6 (#13456)


Signed-off-by: youkaichao <youkaichao@gmail.com>
parent a1074b3e
...@@ -9,7 +9,7 @@ ...@@ -9,7 +9,7 @@
# the only successful approach is to call cuda driver API in C. # the only successful approach is to call cuda driver API in C.
import dataclasses import dataclasses
from contextlib import contextmanager from contextlib import contextmanager
from typing import Callable, Dict, Optional, Tuple, Union from typing import Any, Callable, Dict, Optional, Tuple, Union
import torch import torch
...@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator( ...@@ -97,7 +97,7 @@ def use_memory_pool_with_allocator(
new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func) new_alloc = get_pluggable_allocator(python_malloc_fn, python_free_func)
mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator) mem_pool = torch.cuda.memory.MemPool(new_alloc._allocator)
with torch.cuda.memory.use_mem_pool(mem_pool): with torch.cuda.memory.use_mem_pool(mem_pool):
yield mem_pool yield mem_pool, new_alloc
class CuMemAllocator: class CuMemAllocator:
...@@ -142,6 +142,7 @@ class CuMemAllocator: ...@@ -142,6 +142,7 @@ class CuMemAllocator:
def __init__(self): def __init__(self):
self.pointer_to_data: Dict[int, AllocationData] = {} self.pointer_to_data: Dict[int, AllocationData] = {}
self.current_tag: str = CuMemAllocator.default_tag self.current_tag: str = CuMemAllocator.default_tag
self.allocator_and_pools: Dict[str, Any] = {}
def python_malloc_callback(self, allocation_handle: HandleType) -> None: def python_malloc_callback(self, allocation_handle: HandleType) -> None:
""" """
...@@ -231,7 +232,13 @@ class CuMemAllocator: ...@@ -231,7 +232,13 @@ class CuMemAllocator:
old_tag = self.current_tag old_tag = self.current_tag
self.current_tag = tag self.current_tag = tag
with use_memory_pool_with_allocator(self.python_malloc_callback, with use_memory_pool_with_allocator(self.python_malloc_callback,
self.python_free_callback): self.python_free_callback) as data:
# start to hit another PyTorch bug in PyTorch 2.6,
# possibly because of gc-related issue w.r.t. the allocator and
# the memory pool.
# to avoid the issue, we keep a reference of the data.
# see https://github.com/pytorch/pytorch/issues/146431 .
self.allocator_and_pools[tag] = data
yield yield
# PyTorch's bug, calling torch.cuda.empty_cache() will error # PyTorch's bug, calling torch.cuda.empty_cache() will error
# when using pluggable allocator, see # when using pluggable allocator, see
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment