Unverified Commit 556b9b7e authored by Jiarui Fang's avatar Jiarui Fang Committed by GitHub
Browse files

[hotfix] Dist Mgr gather torch version (#1284)

* make it faster

* [hotfix] torchvison fx tests

* [hotfix] rename duplicated named test_gpt.py

* [hotfix] dist mgr torch version
parent 7e8114a8
...@@ -88,11 +88,13 @@ class DistSpecManager: ...@@ -88,11 +88,13 @@ class DistSpecManager:
torch.Tensor: a replicated tensor. torch.Tensor: a replicated tensor.
""" """
assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!" assert old_dist_spec.placement.value == 's', f"The old_dist_spec of DistSpecManager._gather must be SHARD!"
if version.parse(torch.__version__) < version.parse("1.11.0"): is_cpu_tensor = False
if tensor.device.type == 'cpu':
# pytorch lower than 1.11 dose not support gather a cpu tensor. # pytorch lower than 1.11 dose not support gather a cpu tensor.
# Therefore, we transfer tensor to GPU before gather. # Therefore, we transfer tensor to GPU before gather.
saved_dev = tensor.device saved_dev = tensor.device
tensor.data = tensor.data.cuda() tensor.data = tensor.data.cuda()
is_cpu_tensor = True
buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())] buffer = [torch.empty_like(tensor) for _ in range(pg.tp_world_size())]
assert tensor.device.type == 'cuda' assert tensor.device.type == 'cuda'
...@@ -106,7 +108,7 @@ class DistSpecManager: ...@@ -106,7 +108,7 @@ class DistSpecManager:
buffer = new_buffer buffer = new_buffer
assert len(buffer) == 1 assert len(buffer) == 1
if version.parse(torch.__version__) < version.parse("1.11.0"): if is_cpu_tensor:
buffer[0].data = buffer[0].data.to(saved_dev) buffer[0].data = buffer[0].data.to(saved_dev)
return buffer[0] return buffer[0]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment