Unverified Commit 0bcc195f authored by Mick's avatar Mick Committed by GitHub
Browse files

fix: minor fix TransportProxyTensor under tp (#8382)

parent 91e3d154
...@@ -92,9 +92,7 @@ class TransportProxyTensor(torch.Tensor): ...@@ -92,9 +92,7 @@ class TransportProxyTensor(torch.Tensor):
} }
state["tensor_data"] = None state["tensor_data"] = None
except Exception as e: except Exception as e:
print_warning_once( # Failed to get CUDA IPC handle (possibly tp). Falling back to default transport.
f"Warning: Failed to get CUDA IPC handle ({e}). Falling back to default transport."
)
state["metadata"]["transport_mode"] = "default" state["metadata"]["transport_mode"] = "default"
state["tensor_data"] = self.as_subclass(torch.Tensor) state["tensor_data"] = self.as_subclass(torch.Tensor)
else: else:
...@@ -751,7 +749,7 @@ def tensor_hash(tensor_list) -> int: ...@@ -751,7 +749,7 @@ def tensor_hash(tensor_list) -> int:
] ]
tensor = torch.concat(tensor_list) tensor = torch.concat(tensor_list)
if tensor.is_cuda: if tensor.is_cuda:
return gpu_tensor_hash(tensor) return gpu_tensor_hash(tensor.cuda())
tensor = tensor.detach().contiguous() tensor = tensor.detach().contiguous()
if tensor.dtype == torch.bfloat16: if tensor.dtype == torch.bfloat16:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment