change / sglang · Commits

Unverified commit ece72491, authored Dec 11, 2024 by Ke Wen, committed by GitHub on Dec 11, 2024

Make torch TP composable with torchao (#2436)
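For readers unfamiliar with the torchao half of this composition: torchao replaces nn.Linear weights in place with quantized tensor subclasses, roughly as in the sketch below. The layer shape and the int8 weight-only config are illustrative assumptions, not code from this commit; the diff that follows adjusts sglang's torch tensor-parallel helpers so the two features can be used together.

# Illustrative torchao usage (assumed layer shape and quantization config; not
# taken from this commit). Requires a CUDA device and the torchao package.
import torch
import torch.nn as nn
from torchao.quantization import quantize_, int8_weight_only

linear = nn.Linear(4096, 4096, bias=False, device="cuda", dtype=torch.bfloat16)
quantize_(linear, int8_weight_only())  # weight is now backed by a quantized tensor subclass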
parent 0fb88aaa

Showing 1 changed file with 66 additions and 5 deletions.

python/sglang/srt/model_parallel.py (+66, -5)
@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """
-from typing import Optional
+from typing import Optional, Sequence

 import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh

 try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt

 from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )


+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
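As a standalone illustration of the slicing arithmetic `_shard_tensor` relies on, the snippet below hard-codes the values `compute_local_shape_and_global_offset` would return for an assumed 2-rank mesh, rank 1, and a (4, 8) tensor sharded on dim 1; the real function computes them from the device mesh and placements.

# Assumed setup for illustration: world_size=2, this rank=1, placement Shard(1).
import torch

full_tensor = torch.arange(32).reshape(4, 8)
shape, offset = (4, 4), (0, 4)  # local shape and global offset on rank 1
slices = tuple(slice(o, o + s) for s, o in zip(shape, offset))
local_tensor = full_tensor[slices]  # columns 4:8; rank 0 would take columns 0:4
print(local_tensor.shape)  # torch.Size([4, 4])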
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)

@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """

+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
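The hunk above is cut off inside `_prepare_output_fn`, but the class docstring explains its purpose: wait on the (possibly asynchronous) collective output before the next op. Below is a minimal sketch of that idea, assuming the standard `AsyncCollectiveTensor.wait()` API; it is not the commit's exact code.

# Sketch of the "maybe wait" behavior described in the class docstring; an
# assumption about the mechanism, not code taken from this commit.
from torch.distributed._functional_collectives import AsyncCollectiveTensor

def _maybe_wait(tensor):
    # Row-wise parallel linears all-reduce their partial outputs; functional
    # collectives hand back an AsyncCollectiveTensor whose result may still be
    # in flight. Waiting materializes a plain tensor so downstream custom ops
    # (e.g. an RMSNorm CustomOp) do not trip over the async wrapper.
    if isinstance(tensor, AsyncCollectiveTensor):
        return tensor.wait()
    return tensor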