Unverified commit af5daa09, authored by Jason Wang, committed via GitHub.
Browse files

Add dtensor support for TE optimizers (#1171)



add dtensor support for te optimizers
Signed-off-by: Jason Wang <jasonwan@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent df699655
......@@ -3,6 +3,7 @@
# See LICENSE for license information.
"""Multi-tensor apply entry."""
from torch.distributed._tensor import DTensor
class MultiTensorApply: # pylint: disable=too-few-public-methods
......@@ -12,6 +13,11 @@ class MultiTensorApply: # pylint: disable=too-few-public-methods
self.chunk_size = chunk_size
def __call__(self, op, noop_flag_buffer, tensor_lists, *args):
for i, ts in enumerate(tensor_lists):
for j, t in enumerate(ts):
if isinstance(t, DTensor):
tensor_lists[i][j] = t._local_tensor
return op(self.chunk_size, noop_flag_buffer, tensor_lists, *args)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.