Unverified Commit e7f4ace0 authored by zhanweidu's avatar zhanweidu Committed by GitHub
Browse files

fix non contiguous tensor value error in save_pretrained (#32422)


Signed-off-by: default avatarduzhanwei <duzhanwei@bytedance.com>
Co-authored-by: default avatarduzhanwei <duzhanwei@bytedance.com>
parent e4522fe3
...@@ -2746,7 +2746,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix ...@@ -2746,7 +2746,7 @@ class PreTrainedModel(nn.Module, ModuleUtilsMixin, GenerationMixin, PushToHubMix
if module_map: if module_map:
filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards") filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
for shard_file, tensors in filename_to_tensors: for shard_file, tensors in filename_to_tensors:
shard = {tensor: state_dict[tensor] for tensor in tensors} shard = {tensor: state_dict[tensor].contiguous() for tensor in tensors}
# remake shard with onloaded parameters if necessary # remake shard with onloaded parameters if necessary
if module_map: if module_map:
if accelerate_version < version.parse("0.31"): if accelerate_version < version.parse("0.31"):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment