Unverified Commit 8b1cea0a authored by Gu Shiqiao's avatar Gu Shiqiao Committed by GitHub
Browse files

Fix variable name for torch version check (#534)

parent 19ac1216
...@@ -14,7 +14,7 @@ class WeightAsyncStreamManager(object): ...@@ -14,7 +14,7 @@ class WeightAsyncStreamManager(object):
self.offload_granularity = offload_granularity self.offload_granularity = offload_granularity
self.init_stream = torch.cuda.Stream(priority=0) self.init_stream = torch.cuda.Stream(priority=0)
torch_version = parse(torch.__version__.split("+")[0]) torch_version = parse(torch.__version__.split("+")[0])
if version >= parse("2.7"): if torch_version >= parse("2.7"):
self.cuda_load_stream = torch.cuda.Stream(priority=1) self.cuda_load_stream = torch.cuda.Stream(priority=1)
self.compute_stream = torch.cuda.Stream(priority=1) self.compute_stream = torch.cuda.Stream(priority=1)
else: else:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment