Unverified Commit cf248976 authored by Ke Wen, committed by GitHub

Add Tensor Parallel to torch_native_llama (#1876)

parent e5c67150
@@ -220,8 +220,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
     return reqs


-@torch.inference_mode()
-def extend(reqs, model_runner):
+def _extend(reqs, model_runner):
     batch = ScheduleBatch.init_new(
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
@@ -237,8 +236,15 @@ def extend(reqs, model_runner):
     return next_token_ids, logits_output.next_token_logits, batch


-@torch.inference_mode()
-def decode(input_token_ids, batch, model_runner):
+def extend(reqs, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _extend(reqs, model_runner)
+
+
+def _decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
     model_worker_batch = batch.get_model_worker_batch()
@@ -248,6 +254,14 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits


+def decode(input_token_ids, batch, model_runner):
+    # Disable inference mode for now when torch TP is applied. We can remove
+    # this workaround once DTensor adds support for inference mode.
+    use_inf_mode = not model_runner.torch_tp_applied
+    with torch.inference_mode(use_inf_mode):
+        return _decode(input_token_ids, batch, model_runner)
+
+
 def correctness_test(
     server_args,
     port_args,
...
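For context on the workaround above: `torch.inference_mode` accepts a boolean `mode`, so the same wrapper can keep inference mode on for the plain path and turn it off only while torch TP (DTensor) is active. A minimal sketch of that toggling pattern follows; the `run_step` helper is hypothetical and not part of this commit.

```python
import torch


def run_step(fn, *args, torch_tp_applied: bool = False):
    # DTensor does not yet support inference mode, so only enable it when
    # torch TP is not applied (mirrors the workaround in extend()/decode()).
    with torch.inference_mode(mode=not torch_tp_applied):
        return fn(*args)
```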
@@ -148,6 +148,15 @@ class ModelRunner:
         min_per_gpu_memory = self.init_torch_distributed()
         self.sampler = Sampler()
         self.load_model()
+
+        # Apply torch TP if model supports it
+        supports_torch_tp = getattr(self.model, "supports_torch_tp", False)
+        if self.tp_size > 1 and supports_torch_tp:
+            self.apply_torch_tp()
+            self.torch_tp_applied = True
+        else:
+            self.torch_tp_applied = False
+
         if server_args.lora_paths is not None:
             self.init_lora_manager()
         self.init_memory_pool(
@@ -551,6 +560,13 @@ class ModelRunner:
         logger.info("Capture cuda graph begin. This can take up to several minutes.")
         self.cuda_graph_runner = CudaGraphRunner(self)

+    def apply_torch_tp(self):
+        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
+        from sglang.srt.model_parallel import tensor_parallel
+
+        device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
+        tensor_parallel(self.model, device_mesh)
+
     def forward_decode(self, forward_batch: ForwardBatch):
         if self.cuda_graph_runner and self.cuda_graph_runner.can_run(forward_batch):
             return self.cuda_graph_runner.replay(forward_batch)
...
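The opt-in above is duck-typed: a model simply exposes a `supports_torch_tp` attribute and the runner reads it with `getattr` before calling `apply_torch_tp`. A rough sketch of that contract (not part of the diff; `TpFriendlyModel` and `maybe_apply_torch_tp` are hypothetical names) is shown below.

```python
import torch


class TpFriendlyModel(torch.nn.Module):
    # Flag read by ModelRunner via getattr(self.model, "supports_torch_tp", False)
    supports_torch_tp = True


def maybe_apply_torch_tp(runner) -> bool:
    # Mirrors the logic added to ModelRunner.__init__ above: only shard when
    # more than one TP rank is configured and the model opts in.
    if runner.tp_size > 1 and getattr(runner.model, "supports_torch_tp", False):
        runner.apply_torch_tp()
        return True
    return False
```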
"""
Common utilities for torch model parallelism.
"""
from typing import Optional, Sequence
import torch
from torch.distributed.device_mesh import DeviceMesh
try:
from torch.distributed.tensor import DTensor, Shard
except ImportError:
# torch 2.4 or older
from torch.distributed._tensor import DTensor, Shard
from torch.distributed._functional_collectives import AsyncCollectiveTensor
from torch.distributed.tensor.parallel import (
ColwiseParallel,
RowwiseParallel,
parallelize_module,
)
class ColwiseParallelSharded(ColwiseParallel):
"""
A version of ColwiseParallel where the local weight has been already
sharded. This is used for the fused wqkv case, where during loading, we
already sharded wq, wk, wv before fusing them.
"""
# Override the _partition_linear_fn in ColwiseParallel
def _partition_linear_fn(self, name, module, device_mesh):
# colwise shard weight/bias to Shard(0), weight be Shard(0)
# means Colwise as Linear is input * weight^T + bias, where
# weight would become Shard(1)
for name, param in module.named_parameters():
dtensor = DTensor.from_local(param, device_mesh, [Shard(0)])
dist_param = torch.nn.Parameter(dtensor, requires_grad=False)
module.register_parameter(name, dist_param)
class RowwiseParallelMaybeWait(RowwiseParallel):
"""
A version of RowwiseParallel that waits for the output (establish dependency
between comm stream and compute stream in CUDA sense) before going into the
next op. This is needed to workaround the current interaction between
AsyncCollectiveTensor and custom ops, such as `class RMSNorm(CustomOp)`.
"""
@staticmethod
def _prepare_output_fn(output_layouts, use_local_output, mod, outputs, device_mesh):
outputs = super(
RowwiseParallelMaybeWait, RowwiseParallelMaybeWait
)._prepare_output_fn(
output_layouts, use_local_output, mod, outputs, device_mesh
)
# wait for the output to be ready
if isinstance(outputs, AsyncCollectiveTensor):
return outputs.wait()
else:
return outputs
def tensor_parallel(
module: torch.nn.Module,
device_mesh: Optional[DeviceMesh] = None,
):
"""
Tensor parallelize the model across the given device mesh.
Args:
module (`torch.nn.Module`):
The module to tensor parallelize.
device_mesh (`torch.distributed.DeviceMesh`):
The device mesh to use for tensor parallelism.
"""
# Tensor parallelize a nn.Module based on the `_tp_plan` attribute of the module.
# No op if `_tp_plan` attribute does not exist under the module.
# This is a helper function to be used with `model.apply` to recursively
# parallelize a model.
def tplize(mod: torch.nn.Module) -> None:
tp_plan = getattr(mod, "_tp_plan", None)
if tp_plan is None:
return
for child_name, tp_style in tp_plan.items():
submod = mod.get_submodule(child_name)
if tp_style == "Colwise":
parallelize_module(submod, device_mesh, ColwiseParallel())
elif tp_style == "Rowwise":
parallelize_module(submod, device_mesh, RowwiseParallelMaybeWait())
elif tp_style == "Colwise_Sharded":
parallelize_module(submod, device_mesh, ColwiseParallelSharded())
else:
raise ValueError(f"Unknown TP style {tp_style}")
# `apply` is a native method of `nn.Module` that recursively applies a
# function to every submodule.
module.apply(tplize)
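To make the `_tp_plan` contract concrete, here is a self-contained sketch (not part of the commit) of how a module could opt in and how `tensor_parallel` would shard it. The `ToyMLP` class, its dimensions, and the launch comment are illustrative assumptions; only the `_tp_plan` keys and the `tensor_parallel` entry point come from the module above.

```python
import os

import torch
import torch.distributed as dist

from sglang.srt.model_parallel import tensor_parallel


class ToyMLP(torch.nn.Module):
    # Plan consumed by tensor_parallel(): keys are child module names, values
    # are one of "Colwise", "Rowwise", or "Colwise_Sharded".
    _tp_plan = {
        "up_proj": "Colwise",
        "down_proj": "Rowwise",
    }

    def __init__(self, hidden: int = 128):
        super().__init__()
        self.up_proj = torch.nn.Linear(hidden, 4 * hidden, bias=False)
        self.down_proj = torch.nn.Linear(4 * hidden, hidden, bias=False)

    def forward(self, x):
        return self.down_proj(torch.nn.functional.silu(self.up_proj(x)))


if __name__ == "__main__":
    # Assumes a torchrun launch, e.g.: torchrun --nproc-per-node=2 toy_tp.py
    world_size = int(os.environ.get("WORLD_SIZE", "1"))
    mesh = torch.distributed.init_device_mesh("cuda", (world_size,))
    model = ToyMLP().cuda()
    tensor_parallel(model, mesh)  # applies each submodule's _tp_plan recursively
```

Because every submodule names only its own children, the plan composes recursively through `module.apply` without any central registry.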
@@ -17,6 +17,31 @@ limitations under the License.
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""

+# PyTorch Tensor Parallel Available for This Model
+"""
+This model supports tensor parallelism (TP) using the PyTorch tensor parallel package.
+Reference: https://pytorch.org/docs/stable/distributed.tensor.parallel.html
+
+Here is a quick example to enable TP:
+```python
+from sglang.srt.model_parallel import tensor_parallel
+
+device_mesh = torch.distributed.init_device_mesh("cuda", (tp_size,))
+tensor_parallel(model, device_mesh)
+```
+
+An end-to-end example can be found in `python/sglang/bench_latency.py`.
+You can run it with the following command:
+```bash
+$ python3 -m sglang.bench_latency --correct \
+    --model meta-llama/Meta-Llama-3-8B \
+    --json-model-override-args '{"architectures": ["TorchNativeLlamaForCausalLM"]}' \
+    --tensor-parallel-size 2 \
+    --disable-cuda-graph
+```
+We will enable CUDA Graph support soon.
+"""
+
 import types
 from typing import Any, Dict, Iterable, Optional, Tuple
@@ -24,7 +49,10 @@ import torch
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import LlamaConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
+from vllm.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
@@ -41,35 +69,45 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

+tp_size = get_tensor_model_parallel_world_size()
+tp_rank = get_tensor_model_parallel_rank()
+

 def gate_up_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[int] = None,
+    loaded_shard_id: int,
 ):
-    if loaded_shard_id is None:
-        shard_offsets: List[Tuple[int, int, int]] = []
-        current_shard_offset = 0
-        for i, output_size in enumerate(self.output_sizes):
-            shard_offsets.append((i, current_shard_offset, output_size))
-            current_shard_offset += output_size
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        assert loaded_shard_id < len(self.output_sizes)
-        param_data = param.data
-        shard_size = loaded_weight.shape[0]
-        shard_offset = loaded_shard_id * shard_size
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
+    # shard_id: (shard_offset, shard_size)
+    gate_up_offsets = {}
+    current_shard_offset = 0
+    for i, output_size in enumerate(self.output_sizes):
+        # Everything shrinks by tp_size if TP enabled
+        output_size = output_size // tp_size
+        gate_up_offsets[i] = (current_shard_offset, output_size)
+        current_shard_offset += output_size
+    # Re-size the param to the size after TP
+    if current_shard_offset != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, current_shard_offset).clone()
+
+    # Now load gate or up
+    assert loaded_shard_id < len(self.output_sizes)
+    param_data = param.data
+    shard_offset, shard_size = gate_up_offsets[loaded_shard_id]
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
     return


 class LlamaMLP(nn.Module):
+    _tp_plan = {
+        "gate_up_proj": "Colwise_Sharded",
+        "down_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         hidden_size: int,
@@ -104,62 +142,44 @@ class LlamaMLP(nn.Module):
         return x


-def _get_shard_offset_mapping(self, loaded_shard_id: str):
-    shard_offset_mapping = {
-        "q": 0,
-        "k": self.num_heads * self.head_size,
-        "v": (self.num_heads + self.num_kv_heads) * self.head_size,
-        "total": (self.num_heads + 2 * self.num_kv_heads) * self.head_size,
-    }
-    return shard_offset_mapping.get(loaded_shard_id)
-
-
-def _get_shard_size_mapping(self, loaded_shard_id: str):
-    shard_size_mapping = {
-        "q": self.num_heads * self.head_size,
-        "k": self.num_kv_heads * self.head_size,
-        "v": self.num_kv_heads * self.head_size,
-    }
-    return shard_size_mapping.get(loaded_shard_id)
-
-
 def qkv_proj_weight_loader(
     self,
     param: Parameter,
     loaded_weight: torch.Tensor,
-    loaded_shard_id: Optional[str] = None,
+    loaded_shard_id: str,
 ):
-    if loaded_shard_id is None:
-        shard_offsets = [
-            # (shard_id, shard_offset, shard_size)
-            ("q", 0, self.total_num_heads * self.head_size),
-            (
-                "k",
-                self.total_num_heads * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-            (
-                "v",
-                (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
-                self.total_num_kv_heads * self.head_size,
-            ),
-        ]
-        for shard_id, shard_offset, shard_size in shard_offsets:
-            loaded_weight_shard = loaded_weight.narrow(
-                param.output_dim, shard_offset, shard_size
-            )
-            self.weight_loader(param, loaded_weight_shard, shard_id)
-    else:
-        shard_offset = self._get_shard_offset_mapping(loaded_shard_id)
-        shard_size = self._get_shard_size_mapping(loaded_shard_id)
-        param_data = param.data
-        param_data = param_data.narrow(0, shard_offset, shard_size)
-        assert param_data.shape == loaded_weight.shape
-        param_data.copy_(loaded_weight)
+    num_heads = self.num_heads // tp_size
+    num_kv_heads = self.num_kv_heads // tp_size
+    # shard_id: (shard_offset, shard_size)
+    qkv_offsets = {
+        "q": (0, num_heads * self.head_size),
+        "k": (num_heads * self.head_size, num_kv_heads * self.head_size),
+        "v": (
+            (num_heads + num_kv_heads) * self.head_size,
+            num_kv_heads * self.head_size,
+        ),
+    }
+    total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]
+    # Re-size the param to the size after TP
+    if total_size != param.shape[0]:
+        # The clone will free the original, full tensor
+        param.data = param.data.narrow(0, 0, total_size).clone()
+
+    # Now load q, k or v
+    shard_offset, shard_size = qkv_offsets[loaded_shard_id]
+    param_data = param.data
+    param_data = param_data.narrow(0, shard_offset, shard_size)
+    loaded_weight = loaded_weight.narrow(0, tp_rank * shard_size, shard_size)
+    assert param_data.shape == loaded_weight.shape
+    param_data.copy_(loaded_weight)
     return


 class LlamaAttention(nn.Module):
+    _tp_plan = {
+        "qkv_proj": "Colwise_Sharded",
+        "o_proj": "Rowwise",
+    }
+
     def __init__(
         self,
         config: LlamaConfig,
@@ -176,7 +196,6 @@ class LlamaAttention(nn.Module):
     ) -> None:
         super().__init__()
         self.hidden_size = hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = num_heads
         assert self.total_num_heads % tp_size == 0
         self.num_heads = self.total_num_heads // tp_size
@@ -205,20 +224,12 @@ class LlamaAttention(nn.Module):
             (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_dim,
             bias=False,
         )
-        self.qkv_proj.total_num_heads = self.total_num_heads
         self.qkv_proj.head_size = self.head_dim
-        self.qkv_proj.total_num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.num_heads = self.total_num_heads
         self.qkv_proj.num_kv_heads = self.total_num_kv_heads
         self.qkv_proj.weight_loader = types.MethodType(
             qkv_proj_weight_loader, self.qkv_proj
         )
-        self.qkv_proj._get_shard_offset_mapping = types.MethodType(
-            _get_shard_offset_mapping, self.qkv_proj
-        )
-        self.qkv_proj._get_shard_size_mapping = types.MethodType(
-            _get_shard_size_mapping, self.qkv_proj
-        )
         self.qkv_proj.weight.weight_loader = self.qkv_proj.weight_loader
         self.qkv_proj.weight.output_dim = 0
         self.o_proj = torch.nn.Linear(
@@ -385,6 +396,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.torchao_config = global_server_args_dict["torchao_config"]
+        self.supports_torch_tp = True
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
...
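The per-rank arithmetic in the new weight loaders is easier to see with concrete numbers. The sketch below is a standalone illustration (the values are assumed, Llama-3-8B-like: 32 query heads, 8 KV heads, head size 128, tp_size 2); it only mirrors the offset math, not the actual parameter copies.

```python
# Hypothetical numbers to illustrate the qkv offset math used above.
tp_size, tp_rank = 2, 1          # pretend we are rank 1 of 2
total_heads, total_kv_heads, head_size = 32, 8, 128

num_heads = total_heads // tp_size        # 16 query heads per rank
num_kv_heads = total_kv_heads // tp_size  # 4 KV heads per rank

# shard_id: (shard_offset, shard_size) within this rank's fused qkv weight
qkv_offsets = {
    "q": (0, num_heads * head_size),                                         # (0, 2048)
    "k": (num_heads * head_size, num_kv_heads * head_size),                  # (2048, 512)
    "v": ((num_heads + num_kv_heads) * head_size, num_kv_heads * head_size), # (2560, 512)
}
total_size = qkv_offsets["v"][0] + qkv_offsets["v"][1]  # 3072 rows per rank after TP

# For shard "k", this rank copies rows [tp_rank * 512, (tp_rank + 1) * 512) of the
# full checkpoint's k weight into rows [2048, 2560) of its local fused qkv weight.
shard_offset, shard_size = qkv_offsets["k"]
src_rows = (tp_rank * shard_size, (tp_rank + 1) * shard_size)  # (512, 1024)
print(total_size, shard_offset, shard_size, src_rows)
```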
import unittest

from sglang.test.test_utils import is_in_ci, run_bench_latency


class TestTorchTP(unittest.TestCase):
    def test_torch_native_llama(self):
        output_throughput = run_bench_latency(
            "meta-llama/Meta-Llama-3-8B",
            [
                "--tp",
                "2",
                "--json-model-override-args",
                '{"architectures": ["TorchNativeLlamaForCausalLM"]}',
                "--disable-cuda-graph",
            ],
        )

        if is_in_ci():
            assert output_throughput > 0, f"{output_throughput=}"


if __name__ == "__main__":
    unittest.main()