Unverified Commit bcb74974 authored by Gu Shiqiao's avatar Gu Shiqiao Committed by GitHub
Browse files

fix cuda stream (#532)

parent 5be4fe5a
......@@ -6,14 +6,20 @@ from collections import OrderedDict
import torch
from loguru import logger
from packaging.version import parse
class WeightAsyncStreamManager(object):
def __init__(self, offload_granularity):
self.offload_granularity = offload_granularity
self.init_stream = torch.cuda.Stream(priority=0)
self.cuda_load_stream = torch.cuda.Stream(priority=1)
self.compute_stream = torch.cuda.Stream(priority=1)
torch_version = parse(torch.__version__.split("+")[0])
if version >= parse("2.7"):
self.cuda_load_stream = torch.cuda.Stream(priority=1)
self.compute_stream = torch.cuda.Stream(priority=1)
else:
self.cuda_load_stream = torch.cuda.Stream(priority=0)
self.compute_stream = torch.cuda.Stream(priority=-1)
def init_cuda_buffer(self, blocks_cuda_buffer=None, phases_cuda_buffer=None):
if self.offload_granularity == "block":
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment