"vscode:/vscode.git/clone" did not exist on "8d2fca07e85af51f50e297d14e99318c1f665a9c"
Unverified Commit f8468b4f authored by qihqi's avatar qihqi Committed by GitHub
Browse files

For xla tensors, use an alternative way to get a unique id (#25802)

* For xla tensors, use an alternative way to get a unique id

Because xla tensors don't have storage.

* add is_torch_tpu_available check
parent 716bb2e3
......@@ -19,7 +19,7 @@ from packaging import version
from safetensors.torch import storage_ptr, storage_size
from torch import nn
from .utils import logging
from .utils import is_torch_tpu_available, logging
ALL_LAYERNORM_LAYERS = [nn.LayerNorm]
......@@ -285,4 +285,15 @@ def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
non-overlapping lifetimes may have the same id.
"""
return tensor.device, storage_ptr(tensor), storage_size(tensor)
if tensor.device.type == "xla" and is_torch_tpu_available():
# NOTE: xla tensors dont have storage
# use some other unique id to distinguish.
# this is a XLA tensor, it must be created using torch_xla's
# device. So the following import is safe:
import torch_xla
unique_id = torch_xla._XLAC._xla_get_tensor_id(tensor)
else:
unique_id = storage_ptr(tensor)
return tensor.device, unique_id, storage_size(tensor)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment