Unverified Commit f8468b4f authored by qihqi, committed by GitHub

For xla tensors, use an alternative way to get a unique id (#25802)

* For xla tensors, use an alternative way to get a unique id

Because xla tensors don't have storage.

* add is_torch_tpu_available check
parent 716bb2e3
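
For background, a minimal repro sketch of the failure mode this commit addresses. It assumes torch_xla is installed and an XLA device is available; the claim that storage_ptr collapses to the same degenerate value for distinct XLA tensors is an assumption about how safetensors resolves data pointers on storage-less tensors, not something stated in the commit itself.

# Hypothetical repro; not part of the commit.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
from safetensors.torch import storage_ptr

device = xm.xla_device()
a = torch.zeros(4, device=device)
b = torch.ones(4, device=device)

# XLA tensors don't own a regular torch storage, so storage_ptr
# degenerates and two distinct tensors can collide (assumed behavior).
print(storage_ptr(a), storage_ptr(b))

# The commit's fallback: ask torch_xla for a per-tensor id instead.
print(torch_xla._XLAC._xla_get_tensor_id(a), torch_xla._XLAC._xla_get_tensor_id(b))
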
@@ -19,7 +19,7 @@ from packaging import version
 from safetensors.torch import storage_ptr, storage_size
 from torch import nn
 
-from .utils import logging
+from .utils import is_torch_tpu_available, logging
 
 
 ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
@@ -285,4 +285,15 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
     guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
     non-overlapping lifetimes may have the same id.
     """
-    return tensor.device, storage_ptr(tensor), storage_size(tensor)
+    if tensor.device.type == "xla" and is_torch_tpu_available():
+        # NOTE: XLA tensors don't have storage; use another unique id
+        # to distinguish them.
+        # This is an XLA tensor: it must have been created via torch_xla's
+        # device, so the following import is safe:
+        import torch_xla
+
+        unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
+    else:
+        unique_id = storage_ptr(tensor)
+
+    return tensor.device, unique_id, storage_size(tensor)
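
After the change, distinct XLA tensors map to distinct ids. A minimal usage sketch follows; the import path transformers.pytorch_utils for id_tensor_storage and the availability of an XLA device are assumptions.

# Usage sketch (assumes a transformers build with this patch, torch_xla,
# and an XLA device).
import torch
import torch_xla.core.xla_model as xm
from transformers.pytorch_utils import id_tensor_storage

device = xm.xla_device()
a = torch.zeros(2, 2, device=device)
b = torch.zeros(2, 2, device=device)

# Each id is (device, unique_id, storage_size); on xla devices unique_id
# now comes from torch_xla._XLAC._xla_get_tensor_id instead of storage_ptr.
print(id_tensor_storage(a) == id_tensor_storage(b))  # expected: False (distinct tensors)
print(id_tensor_storage(a) == id_tensor_storage(a))  # expected: True
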