Unverified Commit 69672f11 authored by youkaichao's avatar youkaichao Committed by GitHub
Browse files

[core][distributed] simplify code to support pipeline parallel (#6406)

parent 44874a0b
...@@ -46,9 +46,7 @@ steps: ...@@ -46,9 +46,7 @@ steps:
fast_check: true fast_check: true
commands: commands:
- pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.8/flashinfer-0.0.8+cu121torch2.3-cp310-cp310-linux_x86_64.whl
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=FLASHINFER pytest -v -s basic_correctness/test_basic_correctness.py
- VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py
- VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py
......
...@@ -28,10 +28,8 @@ def test_vllm_gc_ed(): ...@@ -28,10 +28,8 @@ def test_vllm_gc_ed():
assert weak_llm() is None assert weak_llm() is None
@pytest.mark.skipif(is_hip()
and os.getenv("VLLM_ATTENTION_BACKEND") == "FLASHINFER",
reason="Flashinfer does not support ROCm/HIP.")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("backend", ["FLASH_ATTN", "XFORMERS", "FLASHINFER"])
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
@pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("max_tokens", [5])
@pytest.mark.parametrize("enforce_eager", [False, True]) @pytest.mark.parametrize("enforce_eager", [False, True])
...@@ -40,10 +38,17 @@ def test_models( ...@@ -40,10 +38,17 @@ def test_models(
vllm_runner, vllm_runner,
example_prompts, example_prompts,
model: str, model: str,
backend: str,
dtype: str, dtype: str,
max_tokens: int, max_tokens: int,
enforce_eager: bool, enforce_eager: bool,
) -> None: ) -> None:
if backend == "FLASHINFER" and is_hip():
pytest.skip("Flashinfer does not support ROCm/HIP.")
os.environ["VLLM_ATTENTION_BACKEND"] = backend
with hf_runner(model, dtype=dtype) as hf_model: with hf_runner(model, dtype=dtype) as hf_model:
hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
......
...@@ -27,7 +27,6 @@ from vllm.attention import Attention, AttentionMetadata ...@@ -27,7 +27,6 @@ from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed.parallel_state import ( from vllm.distributed.parallel_state import (
get_pp_group, get_tensor_model_parallel_world_size) get_pp_group, get_tensor_model_parallel_world_size)
from vllm.distributed.utils import get_pp_indices
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
QKVParallelLinear, QKVParallelLinear,
...@@ -42,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -42,6 +41,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput from vllm.sequence import IntermediateTensors, SamplerOutput
from .utils import is_pp_missing_parameter, make_layers
class GPT2Attention(nn.Module): class GPT2Attention(nn.Module):
...@@ -183,18 +184,9 @@ class GPT2Model(nn.Module): ...@@ -183,18 +184,9 @@ class GPT2Model(nn.Module):
self.embed_dim = config.hidden_size self.embed_dim = config.hidden_size
self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim) self.wte = VocabParallelEmbedding(config.vocab_size, self.embed_dim)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.start_layer, self.end_layer = get_pp_indices( self.start_layer, self.end_layer, self.h = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
get_pp_group().rank_in_group, lambda: GPT2Block(config, cache_config, quant_config))
get_pp_group().world_size)
self.h = nn.ModuleList(
[nn.Identity() for _ in range(self.start_layer)] + [
GPT2Block(config, cache_config, quant_config)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward( def forward(
...@@ -291,19 +283,20 @@ class GPT2LMHeadModel(nn.Module): ...@@ -291,19 +283,20 @@ class GPT2LMHeadModel(nn.Module):
continue continue
if not name.startswith("transformer."): if not name.startswith("transformer."):
name = "transformer." + name name = "transformer." + name
try:
param = params_dict[name] if is_pp_missing_parameter(name, self):
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
except KeyError:
continue continue
param = params_dict[name]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
# Note(zhuohan): the logic below might break quantized models.
for conv1d_weight_name in ["c_attn", "c_proj", "c_fc"]:
if conv1d_weight_name not in name:
continue
if not name.endswith(".weight"):
continue
loaded_weight = loaded_weight.t()
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)
...@@ -29,8 +29,7 @@ from transformers import LlamaConfig ...@@ -29,8 +29,7 @@ from transformers import LlamaConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_pp_indices, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.layernorm import RMSNorm
...@@ -51,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput ...@@ -51,6 +50,7 @@ from vllm.sequence import IntermediateTensors, SamplerOutput
from vllm.utils import is_hip, print_warning_once from vllm.utils import is_hip, print_warning_once
from .interfaces import SupportsLoRA from .interfaces import SupportsLoRA
from .utils import is_pp_missing_parameter, make_layers
class LlamaMLP(nn.Module): class LlamaMLP(nn.Module):
...@@ -262,20 +262,11 @@ class LlamaModel(nn.Module): ...@@ -262,20 +262,11 @@ class LlamaModel(nn.Module):
config.hidden_size, config.hidden_size,
org_num_embeddings=config.vocab_size, org_num_embeddings=config.vocab_size,
) )
self.start_layer, self.end_layer = get_pp_indices( self.start_layer, self.end_layer, self.layers = make_layers(
config.num_hidden_layers, config.num_hidden_layers,
get_pp_group().rank_in_group, lambda: LlamaDecoderLayer(config=config,
get_pp_group().world_size) cache_config=cache_config,
self.layers = nn.ModuleList( quant_config=quant_config))
[nn.Identity() for _ in range(self.start_layer)] + [
LlamaDecoderLayer(config=config,
cache_config=cache_config,
quant_config=quant_config)
for _ in range(self.start_layer, self.end_layer)
] + [
nn.Identity()
for _ in range(self.end_layer, config.num_hidden_layers)
])
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
...@@ -455,12 +446,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -455,12 +446,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict: if name.endswith(".bias") and name not in params_dict:
continue continue
try:
param = params_dict[name] if is_pp_missing_parameter(name, self):
weight_loader = param.weight_loader continue
weight_loader(param, loaded_weight, shard_id)
except KeyError: param = params_dict[name]
pass weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models. # Skip loading extra bias for GPTQ models.
...@@ -479,13 +472,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -479,13 +472,14 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
continue continue
else: else:
name = remapped_kv_scale_name name = remapped_kv_scale_name
try:
param = params_dict[name] if is_pp_missing_parameter(name, self):
weight_loader = getattr(param, "weight_loader", continue
default_weight_loader)
weight_loader(param, loaded_weight) param = params_dict[name]
except KeyError: weight_loader = getattr(param, "weight_loader",
pass default_weight_loader)
weight_loader(param, loaded_weight)
# If this function is called, it should always initialize KV cache scale # If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should # factors (or else raise an exception). Thus, handled exceptions should
......
from typing import Callable, Dict, List, Tuple
import torch import torch
from vllm.multimodal import BatchedTensors from vllm.multimodal import BatchedTensors
...@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor, ...@@ -39,3 +41,57 @@ def merge_vision_embeddings(input_ids: torch.Tensor,
inputs_embeds[mask] = torch.cat(vision_embeddings) inputs_embeds[mask] = torch.cat(vision_embeddings)
return inputs_embeds return inputs_embeds
class PPMissingLayer(torch.nn.Identity):
    """Identity stand-in for a layer that is absent on this pipeline stage.

    ``make_layers`` puts an instance of this class at every layer index it
    does not build with the real layer factory, and
    ``get_pp_missing_layer_names`` finds those slots via ``isinstance``
    checks against this class.
    """

    def __init__(self, *args, **kwargs):
        # Any constructor arguments are accepted and deliberately dropped;
        # the parent is always initialised without arguments.
        super().__init__()
def make_layers(
    num_hidden_layers: int, layer_fn: Callable[[], torch.nn.Module]
) -> Tuple[int, int, torch.nn.ModuleList]:
    """Build the layer list for the current pipeline-parallel rank.

    Args:
        num_hidden_layers: Total number of layers in the full model.
        layer_fn: Zero-argument factory; called once for every layer index
            in ``[start_layer, end_layer)``.

    Returns:
        ``(start_layer, end_layer, modules)`` where the two indices come from
        ``get_pp_indices`` for this rank and ``modules`` is a ``ModuleList``
        holding real layers inside that range and ``PPMissingLayer``
        placeholders at every other index.
    """
    # Imported lazily, inside the function, exactly as the original does.
    from vllm.distributed.parallel_state import get_pp_group
    from vllm.distributed.utils import get_pp_indices

    pp_rank = get_pp_group().rank_in_group
    pp_world_size = get_pp_group().world_size
    start_layer, end_layer = get_pp_indices(num_hidden_layers, pp_rank,
                                            pp_world_size)

    # Assemble in three steps: placeholders before the owned range, real
    # layers for the owned range, placeholders after it.
    modules = torch.nn.ModuleList()
    modules.extend([PPMissingLayer() for _ in range(start_layer)])
    modules.extend([layer_fn() for _ in range(start_layer, end_layer)])
    modules.extend(
        [PPMissingLayer() for _ in range(end_layer, num_hidden_layers)])
    return start_layer, end_layer, modules
# NOTE: don't use lru_cache here because it can prevent garbage collection
# Maps id(model) -> names of that model's PPMissingLayer submodules.
# NOTE(review): entries are never evicted and id() values may be reused once a
# model has been garbage collected -- confirm that stale hits cannot occur in
# the callers' lifecycle.
_model_to_pp_missing_layer_names: Dict[int, List[str]] = {}


def get_pp_missing_layer_names(model: torch.nn.Module) -> List[str]:
    """Return the module names of every ``PPMissingLayer`` inside ``model``.

    The names are those yielded by ``model.named_modules()``. The result is
    computed on the first call for a given model object, memoised under
    ``id(model)``, and the same list object is returned on later calls.
    """
    cache_key = id(model)
    cached_names = _model_to_pp_missing_layer_names.get(cache_key)
    if cached_names is not None:
        return cached_names

    placeholder_names = [
        module_name for module_name, submodule in model.named_modules()
        if isinstance(submodule, PPMissingLayer)
    ]
    _model_to_pp_missing_layer_names[cache_key] = placeholder_names
    return placeholder_names
def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool:
    """Check if a parameter is missing in a pipeline parallel model.

    A parameter counts as missing when it lives under (or is named exactly
    like) one of the placeholder layers reported by
    ``get_pp_missing_layer_names(model)``.

    Args:
        name: Dotted parameter name, e.g. ``"model.layers.3.mlp.weight"``.
        model: The model whose placeholder layers are consulted.

    Returns:
        ``True`` if ``name`` belongs to a missing layer, ``False`` otherwise.
    """
    for missing_layer_name in get_pp_missing_layer_names(model):
        if not missing_layer_name:
            # An empty module name is a prefix of every parameter name; keep
            # treating everything as missing in that case.
            return True
        # Match on a dotted-path boundary. A bare ``startswith`` would treat
        # the missing layer "layers.1" as a prefix of the present layer
        # "layers.10", and that layer's weights would be skipped by mistake.
        if (name == missing_layer_name
                or name.startswith(missing_layer_name + ".")):
            return True
    return False
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment