"...resnet50_tensorflow.git" did not exist on "7567d57572968d6b875c7218a7910ab6a3aa1b55"
Unverified Commit d68d6665 authored by Thomas Wang, committed by GitHub

Support shared tensors (#23871)



* Support shared storage

* Really be sure we have the same storage

* Make style

* - Refactor storage identifier mechanism
 - Group everything into a single for loop

* Make style

* PR

* make style

* Update src/transformers/pytorch_utils.py
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

---------
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
parent 68d53bc7
src/transformers/modeling_utils.py

@@ -41,6 +41,7 @@ from .pytorch_utils import ( # noqa: F401
     Conv1D,
     apply_chunking_to_forward,
     find_pruneable_heads_and_indices,
+    id_tensor_storage,
     prune_conv1d_layer,
     prune_layer,
     prune_linear_layer,
@@ -304,26 +305,31 @@ def shard_checkpoint(
     """
     max_shard_size = convert_file_size_to_int(max_shard_size)

-    sharded_state_dicts = []
-    current_block = {}
-    current_block_size = 0
+    sharded_state_dicts = [{}]
+    last_block_size = 0
     total_size = 0
+    storage_id_to_block = {}

     for key, weight in state_dict.items():
+        storage_id = id_tensor_storage(weight)
+        # If a `weight` shares the same underlying storage as another tensor, we put `weight` in the same `block`
+        if storage_id in storage_id_to_block:
+            block_id = storage_id_to_block[storage_id]
+            sharded_state_dicts[block_id][key] = weight
+            continue
+
         weight_size = weight.numel() * dtype_byte_size(weight.dtype)

         # If this weight is going to tip up over the maximal size, we split.
-        if current_block_size + weight_size > max_shard_size:
-            sharded_state_dicts.append(current_block)
-            current_block = {}
-            current_block_size = 0
+        if last_block_size + weight_size > max_shard_size:
+            sharded_state_dicts.append({})
+            last_block_size = 0

-        current_block[key] = weight
-        current_block_size += weight_size
+        sharded_state_dicts[-1][key] = weight
+        last_block_size += weight_size
         total_size += weight_size
+        storage_id_to_block[storage_id] = len(sharded_state_dicts) - 1

-    # Add the last block
-    sharded_state_dicts.append(current_block)
     # If we only have one shard, we return it
     if len(sharded_state_dicts) == 1:
...
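To illustrate the new grouping, here is a minimal sketch (not part of this diff): tied weights, such as input embeddings reused as an LM head, share one storage, and with this patch shard_checkpoint keeps them in the same shard instead of splitting purely by size. The toy key names and the 200-byte shard limit below are made up for the example.

    import torch
    from transformers.modeling_utils import shard_checkpoint

    emb = torch.nn.Embedding(10, 4)            # 10 * 4 fp32 values = 160 bytes
    head = torch.nn.Linear(4, 10, bias=False)
    head.weight = emb.weight                   # weight tying: one shared storage

    state_dict = {
        "embed.weight": emb.weight,            # 160B, opens the first block
        "lm_head.weight": head.weight,         # same storage -> same block, adds no size
        "other.bias": torch.zeros(40),         # 160B, and 160 + 160 > 200 -> new block
    }
    shards, index = shard_checkpoint(state_dict, max_shard_size=200)

    shard_of = {key: name for name, shard in shards.items() for key in shard}
    assert shard_of["embed.weight"] == shard_of["lm_head.weight"]  # grouped together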
src/transformers/pytorch_utils.py

@@ -16,6 +16,7 @@ from typing import Callable, List, Optional, Set, Tuple, Union

 import torch
 from packaging import version
+from safetensors.torch import storage_ptr, storage_size
 from torch import nn

 from .utils import logging
@@ -277,3 +278,13 @@ def meshgrid(
     if indexing != "ij":
         raise ValueError('torch.meshgrid only supports `indexing="ij"` for torch<1.10.')
     return torch.meshgrid(*tensors)
+
+
+def id_tensor_storage(tensor: torch.Tensor) -> Tuple[torch.device, int, int]:
+    """
+    Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
+    example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
+    guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
+    non-overlapping lifetimes may have the same id.
+    """
+    return tensor.device, storage_ptr(tensor), storage_size(tensor)
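A quick usage sketch for the new helper (not from the diff): views of one tensor report the same identifier, while a copy gets its own storage. This assumes safetensors is installed, since it supplies storage_ptr and storage_size.

    import torch
    from transformers.pytorch_utils import id_tensor_storage

    base = torch.arange(6, dtype=torch.float32)
    view = base.view(2, 3)   # reinterprets base's storage, no new allocation
    copy = base.clone()      # fresh storage holding the same values

    assert id_tensor_storage(base) == id_tensor_storage(view)
    assert id_tensor_storage(base) != id_tensor_storage(copy)

As the docstring warns, only storages with overlapping lifetimes are guaranteed distinct identifiers; here base and copy are both alive, so their data pointers cannot collide.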