Unverified Commit cf248976 authored by Ke Wen, committed by GitHub

Add Tensor Parallel to torch_native_llama (#1876)

parent e5c67150
......@@ -220,8 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
return reqs
@torch.inference_mode()
def extend(reqs, model_runner):
def _extend(reqs, model_runner):
batch = ScheduleBatch.init_new(
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
......@@ -237,8 +236,15 @@ def extend(reqs, model_runner):
return next_token_ids, logits_output.next_token_logits, batch
@torch.inference_mode()
def decode(input_token_ids, batch, model_runner):
def extend(reqs, model_runner):
# Disable inference mode for now when torch TP is applied. We can remove
# this workaround once DTensor adds support for inference mode.
use_inf_mode = not model_runner.torch_tp_applied
with torch.inference_mode(use_inf_mode):
return _extend(reqs, model_runner)
def _decode(input_token_ids, batch, model_runner):
batch.output_ids = input_token_ids
batch.prepare_for_decode()
model_worker_batch = batch.get_model_worker_batch()
......@@ -248,6 +254,14 @@ def decode(input_token_ids, batch, model_runner):
return next_token_ids, logits_output.next_token_logits
def decode(input_token_ids, batch, model_runner):
# Disable inference mode for now when torch TP is applied. We can remove
# this workaround once DTensor adds support for inference mode.
use_inf_mode = not model_runner.torch_tp_applied
with torch.inference_mode(use_inf_mode):
return _decode(input_token_ids, batch, model_runner)
def correctness_test(
server_args,
port_args,
......
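The `torch.inference_mode(use_inf_mode)` wrappers above hinge on `torch.inference_mode` accepting a boolean `mode` flag, so the same call site works whether or not torch TP is active. A minimal sketch of that pattern (the helper below is illustrative only, not part of this change):

```python
import torch

def run_step(fn, *args, torch_tp_applied: bool = False):
    # DTensor does not yet support inference mode, so enable it only when
    # torch TP is not in use; with mode=False the guard leaves normal
    # (grad-tracking-capable) execution in place.
    with torch.inference_mode(mode=not torch_tp_applied):
        return fn(*args)

# e.g. run_step(_decode, input_token_ids, batch, model_runner,
#               torch_tp_applied=model_runner.torch_tp_applied)
```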
......@@ -148,6 +148,15 @@ class ModelRunner:
min_per_gpu_memory = self.init_torch_distributed()
self.sampler = Sampler()
self.load_model()
# Apply torch TP if model supports it
supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
if self.tp_size > 1 and supports_torch_tp:
self.apply_torch_tp()
self.torch_tp_applied = True
else:
self.torch_tp_applied = False
if server_args.lora_paths is not None:
self.init_lora_manager()
self.init_memory_pool(
......@@ -551,6 +560,13 @@ class ModelRunner:
logger.info("Capture cuda graph begin. This can take up to several minutes.")
self.cuda_graph_runner = CudaGraphRunner(self)
def apply_torch_tp(self):
logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
from sglang.srt.model_parallel import tensor_parallel
device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
tensor_parallel(self.model, device_mesh)
def forward_decode(self, forward_batch: ForwardBatch):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
return self.cuda_graph_runner.replay(forward_batch)
......
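The runner applies torch TP only when `tp_size > 1` and the loaded model opts in through a `supports_torch_tp` attribute. A standalone sketch of that gating and of the 1-D device mesh construction (hypothetical helper; it assumes the default process group is already initialized and spans exactly `tp_size` ranks):

```python
import torch
import torch.distributed as dist
from sglang.srt.model_parallel import tensor_parallel

def maybe_apply_torch_tp(model, device: str, tp_size: int) -> bool:
    # Mirrors the gating in ModelRunner: opt-in flag plus more than one rank.
    if tp_size <= 1 or not getattr(model, "supports_torch_tp", False):
        return False
    # The 1-D mesh is built on top of the existing process group, so the
    # tensor-parallel size must match the world size in this sketch.
    assert dist.get_world_size() == tp_size
    mesh = torch.distributed.init_device_mesh(device, (tp_size,))
    tensor_parallel(model, mesh)
    return True
```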
"""
Common utilities for torch model parallelism.
"""
from typing import Optional, Sequence
import torch
from torch.distributed.device_mesh import DeviceMesh
try:
from torch.distributed.tensor import DTensor, Shard
except ImportError:
# torch 2.4 or older
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)
class ColwiseParallelSharded(ColwiseParallel):
"""
A version of ColwiseParallel where the local weight has already been
sharded. This is used for the fused wqkv case: during loading, wq, wk and
wv are each sharded before being fused.
"""
# Override the _partition_linear_fn in ColwiseParallel
def _partition_linear_fn(self, name, module, device_mesh):
# Wrap the already-sharded local weight/bias as DTensors with Shard(0).
# For nn.Linear, y = x @ W^T + b, so sharding W along dim 0 splits the
# output features, which is exactly column-wise parallelism of the layer.
for name, param in module.named_parameters():
dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
module.register_parameter(name, dist_param)
class RowwiseParallelMaybeWait(RowwiseParallel):
"""
A version of RowwiseParallel that waits for the output (i.e., establishes a
dependency between the communication stream and the compute stream, in the
CUDA sense) before moving on to the next op. This is needed to work around
the current interaction between AsyncCollectiveTensor and custom ops, such
as `class RMSNorm(CustomOp)`.
"""
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = super(
RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
)._prepare_output_fn(
output_layouts, use_local_output, mod, outputs, device_mesh
)
# wait for the output to be ready
if isinstance(outputs, AsyncCollectiveTensor):
return outputs.wait()
else:
return outputs
def tensor_parallel(
module: torch.nn.Module,
device_mesh: Optional[DeviceMesh] = None,
):
"""
Tensor parallelize the model across the given device mesh.
Args:
module (`torch.nn.Module`):
The module to tensor parallelize.
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
# Tensor-parallelize an nn.Module based on its `_tp_plan` attribute.
# This is a no-op if the module has no `_tp_plan` attribute.
# The helper is meant to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
for child_name, tp_style in tp_plan.items():
submod = mod.get_submodule(child_name)
if tp_style == "Colwise":
parallelize_module(submod, device_mesh, ColwiseParallel())
elif tp_style == "Rowwise":
parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
elif tp_style == "Colwise_Sharded":
parallelize_module(submod, device_mesh, ColwiseParallelSharded())
else:
raise ValueError(f"Unknown TP style {tp_style}")
# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
module.apply(tplize)
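To make the `_tp_plan` contract concrete, here is a toy module and how `tensor_parallel` shards it over a 1-D mesh (illustrative names and sizes; assumes `torch.distributed` is already initialized with one GPU per rank):

```python
import torch
import torch.distributed as dist
import torch.nn as nn
from sglang.srt.model_parallel import tensor_parallel

class ToyMLP(nn.Module):
    # Plan consumed by tplize(): up_proj is split over output features
    # (Colwise), down_proj over input features (Rowwise), so the block needs
    # only one all-reduce, performed at the end of down_proj.
    _tp_plan = {"up_proj": "Colwise", "down_proj": "Rowwise"}

    def __init__(self, dim: int = 1024):
        super().__init__()
        self.up_proj = nn.Linear(dim, 4 * dim, bias=False)
        self.down_proj = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        return self.down_proj(torch.relu(self.up_proj(x)))

tp_size = dist.get_world_size()
mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
model = ToyMLP().cuda()
tensor_parallel(model, mesh)  # walks submodules and applies the plan in place
```

The "Colwise_Sharded" style used for the fused projections differs from plain "Colwise" in that it wraps an already rank-local weight via `DTensor.from_local(param, device_mesh, [Shard(0)])` instead of redistributing a full-size weight, matching the sharded loading done in `torch_native_llama.py` below.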
......@@ -17,6 +17,31 @@ limitations under the License.
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
"""Inference-only LLaMA model compatible with HuggingFace weights."""
# PyTorch Tensor Parallel Available for This Model
"""
This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
Here is a quick example to enable TP:
```python
from sglang.srt.model_parallel import tensor_parallel
device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
tensor_parallel(model, device_mesh)
```
An end-to-end example can be found in `python/sglang/bench_latency.py`.
You can run it with the following command:
```bash
$ python3 -m sglang.bench_latency --correct \
--model meta-llama/Meta-Llama-3-8B \
--json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
--tensor-parallel-size 2 \
--disable-cuda-graph
```
We will enable CUDA Graph support soon.
"""
import types
from typing import Any, Dict, Iterable, Optional, Tuple
......@@ -24,7 +49,10 @@ import torch
from torch import nn
from torch.nn.parameter import Parameter
from transformers import LlamaConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
......@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
def gate_up_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None,
loaded_shard_id: int,
):
if loaded_shard_id is None:
shard_offsets: List[Tuple[int, int, int]] = []
for i, output_size in enumerate(self.output_sizes):
shard_offsets.append((i, current_shard_offset, output_size))
current_shard_offset += output_size
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
assert loaded_shard_id < len(self.output_sizes)
param_data = param.data
shard_size = loaded_weight.shape[0]
shard_offset = loaded_shard_id * shard_size
param_data = param_data.narrow(0, shard_offset, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
# shard_id: (shard_offset, shard_size)
gate_up_offsets = {}
current_shard_offset = 0
for i, output_size in enumerate(self.output_sizes):
# Each output size shrinks by a factor of tp_size when TP is enabled
output_size = output_size // tp_size
gate_up_offsets[i] = (current_shard_offset, output_size)
current_shard_offset += output_size
# Re-size the param to the size after TP
if current_shard_offset != param.shape[0]:
# The clone will free the original, full tensor
param.data = param.data.narrow(0, 0, current_shard_offset).clone()
# Now load gate or up
assert loaded_shard_id < len(self.output_sizes)
param_data = param.data
shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class LlamaMLP(nn.Module):
_tp_plan = {
"gate_up_proj": "Colwise_Sharded",
"down_proj": "Rowwise",
}
def __init__(
self,
hidden_size: int,
......@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
return x
def _get_shard_offset_mapping(self, loaded_shard_id: str):
shard_offset_mapping = {
"q": 0,
"k": self.num_heads * self.head_size,
"v": (self.num_heads + self.num_kv_heads) * self.head_size,
"total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
}
return shard_offset_mapping.get(loaded_shard_id)
def _get_shard_size_mapping(self, loaded_shard_id: str):
shard_size_mapping = {
"q": self.num_heads * self.head_size,
"k": self.num_kv_heads * self.head_size,
"v": self.num_kv_heads * self.head_size,
}
return shard_size_mapping.get(loaded_shard_id)
def qkv_proj_weight_loader(
self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None,
loaded_shard_id: str,
):
if loaded_shard_id is None:
shard_offsets = [
# (shard_id, shard_offset, shard_size)
("q", 0, self.total_num_heads * self.head_size),
(
"k",
self.total_num_heads * self.head_size,
self.total_num_kv_heads * self.head_size,
),
(
"v",
(self.total_num_heads + self.total_num_kv_heads) * self.head_size,
self.total_num_kv_heads * self.head_size,
),
]
for shard_id, shard_offset, shard_size in shard_offsets:
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
else:
shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
shard_size = self._get_shard_size_mapping(loaded_shard_id)
param_data = param.data
param_data = param_data.narrow(0, shard_offset, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
return
num_heads = self.num_heads // tp_size
num_kv_heads = self.num_kv_heads // tp_size
# shard_id: (shard_offset, shard_size)
qkv_offsets = {
"q": (0, num_heads * self.head_size),
"k": (num_heads * self.head_size, num_kv_heads * self.head_size),
"v": (
(num_heads + num_kv_heads) * self.head_size,
num_kv_heads * self.head_size,
),
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
# Re-size the param to the size after TP
if total_size != param.shape[0]:
# The clone will free the original, full tensor
param.data = param.data.narrow(0, 0, total_size).clone()
# Now load q, k or v
shard_offset, shard_size = qkv_offsets[loaded_shard_id]
param_data = param.data
param_data = param_data.narrow(0, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
class LlamaAttention(nn.Module):
_tp_plan = {
"qkv_proj": "Colwise_Sharded",
"o_proj": "Rowwise",
}
def __init__(
self,
config: LlamaConfig,
......@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
) -> None:
super().__init__()
self.hidden_size = hidden_size
tp_size = get_tensor_model_parallel_world_size()
self.total_num_heads = num_heads
assert self.total_num_heads % tp_size == 0
self.num_heads = self.total_num_heads // tp_size
......@@ -205,20 +224,12 @@ class LlamaAttention(nn.Module):
(self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
bias=False,
)
self.qkv_proj.total_num_heads = self.total_num_heads
self.qkv_proj.head_size = self.head_dim
self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
self.qkv_proj.num_heads = self.total_num_heads
self.qkv_proj.num_kv_heads = self.total_num_kv_heads
self.qkv_proj.weight_loader = types.MethodType(
qkv_proj_weight_loader, self.qkv_proj
)
self.qkv_proj._get_shard_offset_mapping = types.MethodType(
_get_shard_offset_mapping, self.qkv_proj
)
self.qkv_proj._get_shard_size_mapping = types.MethodType(
_get_shard_size_mapping, self.qkv_proj
)
self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
self.qkv_proj.weight.output_dim = 0
self.o_proj = torch.nn.Linear(
......@@ -385,6 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.torchao_config = global_server_args_dict["torchao_config"]
self.supports_torch_tp = True
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
......
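For a concrete sense of the sharded loading arithmetic in `qkv_proj_weight_loader`, here is a worked example using Meta-Llama-3-8B shapes (32 query heads, 8 KV heads, head_dim 128) with `tp_size = 2`; the numbers are illustrative only:

```python
tp_size, tp_rank = 2, 1
num_heads, num_kv_heads, head_size = 32 // tp_size, 8 // tp_size, 128

# Same layout as qkv_proj_weight_loader: shard_id -> (shard_offset, shard_size)
qkv_offsets = {
    "q": (0, num_heads * head_size),                                          # (0, 2048)
    "k": (num_heads * head_size, num_kv_heads * head_size),                   # (2048, 512)
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size),  # (2560, 512)
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]  # 3072 rows per rank

# The fused qkv weight is created at full size ((32 + 2 * 8) * 128 = 6144 rows)
# and then narrowed to 3072 rows per rank. With tp_rank = 1, the loader copies
# rows [2048, 4096) of the full q weight, rows [512, 1024) of the full k weight,
# and rows [512, 1024) of the full v weight into the local shard.
```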
import unittest
from sglang.test.test_utils import is_in_ci, run_bench_latency
class TestTorchTP(unittest.TestCase):
def test_torch_native_llama(self):
output_throughput = run_bench_latency(
"meta-llama/Meta-Llama-3-8B",
[
"--tp",
"2",
"--json-model-override-args",
'{"architectures": ["TorchNativeLlamaForCausalLM"]}',
"--disable-cuda-graph",
],
)
if is_in_ci():
assert output_throughput > 0, f"{output_throughput=}"
if __name__ == "__main__":
unittest.main()