zhaoyu6 / sglang · Commits · ece72491
"vscode:/vscode.git/clone" did not exist on "8f6274c82be3221b45848836756223a918cd1d07"
Make torch TP composable with torchao (#2436)

Unverified commit ece72491, authored Dec 11, 2024 by Ke Wen, committed by GitHub on Dec 11, 2024. Parent commit: 0fb88aaa.
Showing 1 changed file with 66 additions and 5 deletions.

python/sglang/srt/model_parallel.py (+66, -5)
@@ -2,18 +2,18 @@
 Common utilities for torch model parallelism.
 """
-from typing import Optional
+from typing import Optional, Sequence

 import torch
 import torch.nn as nn
 from torch.distributed.device_mesh import DeviceMesh

 try:
-    from torch.distributed.tensor import DTensor, Shard
+    import torch.distributed.tensor as dt
 except ImportError:
     # torch 2.4 or older
-    from torch.distributed._tensor import DTensor, Shard
+    import torch.distributed._tensor as dt

 from torch.distributed._functional_collectives import AsyncCollectiveTensor
 from torch.distributed.tensor.parallel import (
     ColwiseParallel,
     RowwiseParallel,
@@ -21,6 +21,50 @@ from torch.distributed.tensor.parallel import (
 )


+def _shard_tensor(
+    full_tensor: torch.Tensor,
+    device_mesh: DeviceMesh,
+    placements: Sequence[dt.Shard],
+) -> "dt.DTensor":
+    """
+    Locally shards a full tensor based on indicated sharding arrangement, and
+    returns a DTensor containing the local shard.
+
+    .. warning:: This is a private API that is subject to change. It skips the
+        communication otherwise required by `distribute_tensor`. It is only
+        applicable to cases where all ranks have the same `full_tensor`. For
+        example, in distributed inference all ranks load from the same
+        checkpoint. This API will not check for data equality between ranks, it
+        is thus user's responsibility to ensure the `full_tensor` is the same
+        across ranks.
+
+    Args:
+        full_tensor (torch.Tensor): the full tensor to be sharded.
+        device_mesh (:class:`DeviceMesh`): DeviceMesh to place the
+            DTensor. Must have same dimension as the number of placements.
+        placements (Sequence[:class:`Shard`]): the placements that
+            describes how to place the local tensor on DeviceMesh.
+
+    Returns:
+        A :class:`DTensor` object with the shard as its local tensor.
+
+    Examples:
+        >>> # xdoctest: +SKIP("need world_size and rank")
+        >>> device_mesh = dist.init_device_mesh("cuda", (world_size,))
+        >>> full_tensor = torch.arange(world_size, device=f"cuda:{rank}")
+        >>> dtensor = _shard_tensor(full_tensor, device_mesh, [Shard(1)])
+    """
+    shape, offset = dt._utils.compute_local_shape_and_global_offset(
+        full_tensor.shape, device_mesh, placements
+    )
+    slices = [
+        slice(cur_offset, cur_offset + cur_shape)
+        for cur_shape, cur_offset in zip(shape, offset)
+    ]
+    local_tensor = full_tensor[slices]
+    return dt.DTensor.from_local(local_tensor, device_mesh, placements)
+
+
 class ColwiseParallelSharded(ColwiseParallel):
     """
     A version of ColwiseParallel where the local weight has been already
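The doctest in the docstring above is skipped at run time; the sketch below is a fuller, hedged illustration of the same helper on a 1-D mesh, assuming a single-node torchrun launch with one GPU per rank. Everything other than the `_shard_tensor` call itself (the `run` function, the tensor sizes, the seed) is an assumption for the example, not part of the diff.

# Illustrative sketch (not part of the commit): exercising the _shard_tensor
# helper added above on a 1-D mesh. Launch with torchrun --nproc-per-node=N
# on a single node; sizes and seed are arbitrary.
import torch
import torch.distributed as dist
import torch.distributed.tensor as dt  # torch >= 2.5; older torch: torch.distributed._tensor
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.model_parallel import _shard_tensor  # the helper added in this diff


def run() -> None:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)  # single-node assumption: global rank == local rank
    device_mesh = init_device_mesh("cuda", (world_size,))

    # Every rank holds the same full tensor, e.g. loaded from the same checkpoint;
    # _shard_tensor does not verify this, it is the caller's responsibility.
    torch.manual_seed(0)
    full_weight = torch.randn(1024, 512, device="cuda")

    # Each rank slices out its own dim-0 shard locally; unlike dt.distribute_tensor,
    # no scatter or broadcast is issued.
    dweight = _shard_tensor(full_weight, device_mesh, [dt.Shard(0)])
    print(rank, tuple(dweight.shape), tuple(dweight.to_local().shape))
    # e.g. with N=4: (1024, 512) global shape, (256, 512) local shard on every rank


if __name__ == "__main__":
    dist.init_process_group("nccl")
    run()
    dist.destroy_process_group()

Because every rank already holds the identical full tensor, each rank can slice its own shard locally, which is why the helper can skip the communication that `dt.distribute_tensor` would otherwise perform.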
@@ -34,7 +78,7 @@ class ColwiseParallelSharded(ColwiseParallel):
         # means Colwise as Linear is input * weight^T + bias, where
         # weight would become Shard(1)
         for name, param in module.named_parameters():
-            dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
+            dtensor = dt.DTensor.from_local(param, device_mesh, [dt.Shard(0)])
             dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
             module.register_parameter(name, dist_param)
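The only change in this hunk is qualifying `DTensor` and `Shard` with the `dt` namespace chosen by the import fallback, but the call it preserves is presumably what keeps the colwise path friendly to torchao-quantized weights: `from_local` wraps whatever local tensor it is handed, so a quantized tensor subclass stays intact as the DTensor's local shard. Below is a minimal sketch of that single wrapping step; the function name and the pre-sharded input are hypothetical.

# Minimal sketch (not part of the commit) of the wrapping step that
# ColwiseParallelSharded._partition_linear_fn performs for each parameter.
# The function name and the assumption that `local_shard` is this rank's
# dim-0 shard (possibly a torchao-quantized tensor subclass) are illustrative.
import torch
import torch.nn as nn
import torch.distributed.tensor as dt  # torch >= 2.5; older torch: torch.distributed._tensor
from torch.distributed.device_mesh import DeviceMesh


def wrap_presharded_param(local_shard: torch.Tensor, device_mesh: DeviceMesh) -> nn.Parameter:
    # from_local neither copies nor redistributes: the tensor passed in,
    # whatever its subclass, becomes the DTensor's local tensor on this rank.
    dtensor = dt.DTensor.from_local(local_shard, device_mesh, [dt.Shard(0)])
    return nn.Parameter(dtensor, requires_grad=False)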
@@ -47,6 +91,23 @@ class RowwiseParallelMaybeWait(RowwiseParallel):
     AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
     """

+    def _partition_linear_fn(self, name, module, device_mesh):
+        # Rowwise shard weight to Shard(1), bias to Replicate(), weight be Shard(1)
+        # means Rowwise as nn.Linear is input * weight^T + bias, where
+        # weight would become Shard(0)
+        module.register_parameter(
+            "weight",
+            nn.Parameter(_shard_tensor(module.weight, device_mesh, [dt.Shard(1)])),
+        )
+        if getattr(module, "bias", None) is not None:
+            # The Linear module has bias
+            module.register_parameter(
+                "bias",
+                nn.Parameter(
+                    dt.distribute_tensor(module.bias, device_mesh, [dt.Replicate()])
+                ),
+            )
+
     @staticmethod
     def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
         outputs = super(
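To tie the pieces together, here is a hedged end-to-end usage sketch: a toy two-layer MLP parallelized with the two styles defined in this file, which is roughly how a column-wise projection is paired with a row-wise one in torch tensor parallelism. The `ToyMLP` module, its attribute names, the sizes, and the manual pre-sharding of `up_proj` (standing in for sglang's weight loader) are all assumptions for illustration; only `ColwiseParallelSharded`, `RowwiseParallelMaybeWait`, and `parallelize_module` come from the file above and torch.

# Illustrative usage sketch (not part of the commit). Launch with
# torchrun --nproc-per-node=N on a single node; assumes N divides the layer sizes.
import os
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import parallelize_module

from sglang.srt.model_parallel import ColwiseParallelSharded, RowwiseParallelMaybeWait


class ToyMLP(nn.Module):
    # Hypothetical stand-in for an sglang module; names and sizes are arbitrary.
    def __init__(self, hidden: int = 1024, inner: int = 4096):
        super().__init__()
        self.up_proj = nn.Linear(hidden, inner, bias=False)
        self.down_proj = nn.Linear(inner, hidden, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))


def main() -> None:
    rank = dist.get_rank()
    world_size = dist.get_world_size()
    torch.cuda.set_device(rank)  # single-node assumption: global rank == local rank
    mesh = init_device_mesh("cuda", (world_size,))

    # Same seed on every rank so all ranks construct identical full weights,
    # which RowwiseParallelMaybeWait's use of _shard_tensor relies on.
    torch.manual_seed(0)
    mlp = ToyMLP().cuda()

    # Emulate sglang's weight loader for the column-wise layer: keep only this
    # rank's dim-0 shard, since ColwiseParallelSharded expects the local weight
    # to be pre-sharded and merely wraps it with DTensor.from_local.
    with torch.no_grad():
        shard = mlp.up_proj.weight.chunk(world_size, dim=0)[rank].clone()
        mlp.up_proj.weight = nn.Parameter(shard, requires_grad=False)

    plan = {
        "up_proj": ColwiseParallelSharded(),
        "down_proj": RowwiseParallelMaybeWait(),
    }
    parallelize_module(mlp, mesh, plan)

    x = torch.randn(2, 1024, device="cuda")  # replicated input, same seed on all ranks
    print(rank, tuple(mlp(x).shape))  # (2, 1024) on every rank after the row-wise all-reduce


if __name__ == "__main__":
    dist.init_process_group("nccl")
    main()
    dist.destroy_process_group()

Neither style redistributes tensors across ranks: `ColwiseParallelSharded` wraps an already-sharded local weight with `from_local`, and `RowwiseParallelMaybeWait` slices its replicated full weight locally via `_shard_tensor`, which appears to be what lets torchao-quantized weights pass through unchanged, as the commit title suggests.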