Unverified Commit ece72491 authored by Ke Wen, committed by GitHub

Make torch TP composable with torchao (#2436)

parent 0fb88aaa
@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """

-from typing import Optional
+from typing import Optional, Sequence

 import torch
+import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh

 try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt

 from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )
+
+
+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on the indicated sharding arrangement,
+    and returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks;
+        it is thus the user's responsibility to ensure that `full_tensor` is
+        the same across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the DTensor on.
+            Must have the same number of dimensions as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that describe how
+            to place the local tensor on the DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(0)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)


 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
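
For illustration only (not part of the commit), the sketch below contrasts the two ways of building a sharded weight that appear in this file: `dt.distribute_tensor`, which scatters shards from a source rank, versus the local slicing that `_shard_tensor` performs when every rank already holds an identical copy of the full tensor. It assumes torch >= 2.5, launch via `torchrun --nproc-per-node=2`, and a CPU/gloo process group; the tensor shapes and the file name are invented.

# demo_shard.py -- hypothetical; run with: torchrun --nproc-per-node=2 demo_shard.py
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard, distribute_tensor


def main() -> None:
    dist.init_process_group("gloo")  # CPU backend keeps the sketch hardware-agnostic
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    mesh = init_device_mesh("cpu", (world_size,))

    # Every rank builds the same full weight, as if loaded from a shared checkpoint.
    full_weight = torch.arange(8 * 4, dtype=torch.float32).reshape(8, 4)

    # Path 1: distribute_tensor scatters shards from the source rank (communication).
    scattered = distribute_tensor(full_weight, mesh, [Shard(1)])

    # Path 2: the _shard_tensor approach -- slice the rank-local copy and wrap it,
    # skipping the scatter entirely. Valid only because full_weight is identical
    # on every rank, which is exactly the contract stated in the docstring above.
    local_chunk = full_weight.tensor_split(world_size, dim=1)[rank]
    sliced = DTensor.from_local(local_chunk, mesh, [Shard(1)])

    assert torch.equal(scattered.to_local(), sliced.to_local())
    print(f"rank {rank}: local shard shape {tuple(sliced.to_local().shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()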
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)
@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """

+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shards the weight to Shard(1) and replicates the bias. Since
+        # nn.Linear computes input * weight^T + bias, Shard(1) of the stored
+        # weight corresponds to Shard(0) of weight^T, i.e. a rowwise split of
+        # the effective weight matrix.
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has a bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
...
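
Finally, a sketch of how styles like these are consumed downstream: a tensor-parallel plan maps module names to parallel styles and is applied with `parallelize_module`. The toy MLP and its submodule names (`w_in`, `w_out`) are invented for illustration, and the stock `ColwiseParallel`/`RowwiseParallel` styles stand in for the custom `ColwiseParallelSharded` and `RowwiseParallelMaybeWait` classes from the diff above, which would slot into the same plan. Same launch assumptions as the previous sketch (torchrun, gloo process group, torch >= 2.5).

# demo_tp_plan.py -- hypothetical; run with: torchrun --nproc-per-node=2 demo_tp_plan.py
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyMLP(nn.Module):
    """Stand-in for a transformer MLP block: up-projection then down-projection."""

    def __init__(self, dim: int = 16, hidden: int = 32) -> None:
        super().__init__()
        self.w_in = nn.Linear(dim, hidden, bias=False)
        self.w_out = nn.Linear(hidden, dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.w_out(torch.relu(self.w_in(x)))


def main() -> None:
    dist.init_process_group("gloo")
    mesh = init_device_mesh("cpu", (dist.get_world_size(),))

    torch.manual_seed(0)  # identical weights on every rank, mimicking a shared checkpoint
    model = ToyMLP()

    # Column-parallel up-projection feeding a row-parallel down-projection; the
    # custom styles defined in the file above would be substituted here.
    plan = {
        "w_in": ColwiseParallel(),   # shards the hidden dimension across ranks
        "w_out": RowwiseParallel(),  # shards the reduction dim; output is all-reduced
    }
    parallelize_module(model, mesh, plan)

    out = model(torch.randn(2, 16))
    print(f"rank {dist.get_rank()}: output shape {tuple(out.shape)}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()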