"src/vscode:/vscode.git/clone" did not exist on "7c1b347702c17f282f5da4d41d68e9fdeb7908fa"
Unverified Commit caaad53b authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Support gpt-bigcode model class (#681)

parent 69d19188
......@@ -34,12 +34,11 @@ class LogitProcessorOutput:
@dataclasses.dataclass
class LogitsMetadata:
forward_mode: ForwardMode
extend_seq_lens: torch.Tensor
extend_start_loc: torch.Tensor
# For logprobs
return_logprob: bool
top_logprobs_nums: List[int]
extend_seq_lens: torch.Tensor = None
extend_start_loc: torch.Tensor = None
top_logprobs_nums: List[int] = None
@classmethod
def from_input_metadata(cls, input_metadata: InputMetadata):
......
......@@ -6,6 +6,7 @@ import torch
from flashinfer import BatchDecodeWithPagedKVCacheWrapper
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
from vllm.distributed.parallel_state import graph_capture
from vllm.model_executor.custom_op import CustomOp
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
......@@ -16,8 +17,28 @@ from sglang.srt.managers.controller.infer_batch import (
)
def _to_torch(model: torch.nn.Module, reverse=False):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
if reverse:
sub._forward_method = sub.forward_cuda
else:
sub._forward_method = sub.forward_native
if isinstance(sub, torch.nn.Module):
_to_torch(sub, reverse)
def get_forward(model: torch.nn.Module, use_torch: bool):
if use_torch:
_to_torch(model, reverse=False)
return torch.compile(model.forward, mode="max-autotune-no-cudagraphs")
else:
_to_torch(model, reverse=True)
return model.forward
class CudaGraphRunner:
def __init__(self, model_runner, max_batch_size_to_capture):
def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
self.model_runner = model_runner
self.graphs = {}
self.input_buffers = {}
......@@ -55,6 +76,8 @@ class CudaGraphRunner:
(self.max_bs,), dtype=torch.int32, device="cuda"
)
self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []
def can_run(self, batch_size):
return batch_size < self.max_bs
......@@ -63,18 +86,19 @@ class CudaGraphRunner:
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in batch_size_list:
forward = get_forward(self.model_runner.model, bs in self.compile_bs)
(
graph,
input_buffers,
output_buffers,
flashinfer_handler,
) = self.capture_one_batch_size(bs)
) = self.capture_one_batch_size(bs, forward)
self.graphs[bs] = graph
self.input_buffers[bs] = input_buffers
self.output_buffers[bs] = output_buffers
self.flashinfer_handlers[bs] = flashinfer_handler
def capture_one_batch_size(self, bs):
def capture_one_batch_size(self, bs, forward):
graph = torch.cuda.CUDAGraph()
stream = self.stream
......@@ -127,9 +151,8 @@ class CudaGraphRunner:
skip_flashinfer_init=True,
)
input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper
return self.model_runner.model.forward(
input_ids, input_metadata.positions, input_metadata
)
return forward(input_ids, input_metadata.positions, input_metadata)
for _ in range(2):
run_once()
......
......@@ -244,7 +244,9 @@ class ModelRunner:
logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 17)]
self.cuda_graph_runner = CudaGraphRunner(
self, max_batch_size_to_capture=max(batch_size_list)
self,
max_batch_size_to_capture=max(batch_size_list),
use_torch_compile=self.server_args.enable_torch_compile,
)
try:
self.cuda_graph_runner.capture(batch_size_list)
......
# Adapted from:
# https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
"""Inference-only GPTBigCode model compatible with HuggingFace weights."""
from typing import Iterable, Optional, Tuple
import torch
from torch import nn
from transformers import GPTBigCodeConfig
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.infer_batch import InputMetadata
class GPTBigCodeAttention(nn.Module):
def __init__(
self,
layer_id: int,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
self.hidden_size = config.hidden_size
total_num_heads = config.num_attention_heads
self.tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
assert total_num_heads % self.tensor_model_parallel_world_size == 0
self.num_heads = total_num_heads // self.tensor_model_parallel_world_size
self.head_dim = self.hidden_size // total_num_heads
self.scale = self.head_dim**-0.5
self.multi_query = config.multi_query
if self.multi_query:
total_num_kv_heads = 1
self.num_kv_heads = 1
else:
total_num_kv_heads = total_num_heads
self.num_kv_heads = self.num_heads
self.kv_dim = self.head_dim * self.num_kv_heads
self.c_attn = QKVParallelLinear(
self.hidden_size,
self.head_dim,
total_num_heads,
total_num_kv_heads,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
self.hidden_size,
self.hidden_size,
bias=True,
quant_config=quant_config,
)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
scaling=self.scale,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
[
self.hidden_size // self.tensor_model_parallel_world_size,
self.kv_dim,
self.kv_dim,
],
dim=-1,
)
attn_output = self.attn(q, k, v, input_metadata)
attn_output, _ = self.c_proj(attn_output)
return attn_output
class GPTBigMLP(nn.Module):
def __init__(
self,
intermediate_size: int,
config: GPTBigCodeConfig,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
self.c_fc = ColumnParallelLinear(
hidden_size,
intermediate_size,
bias=True,
quant_config=quant_config,
)
self.c_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=True,
quant_config=quant_config,
)
self.act = get_act_fn(
config.activation_function, quant_config, intermediate_size
)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
hidden_states, _ = self.c_fc(hidden_states)
hidden_states = self.act(hidden_states)
hidden_states, _ = self.c_proj(hidden_states)
return hidden_states
class GPTBigCodeBlock(nn.Module):
def __init__(
self,
layer_id: int,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__()
hidden_size = config.hidden_size
inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.attn = GPTBigCodeAttention(layer_id, config, cache_config, quant_config)
self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
self.mlp = GPTBigMLP(inner_dim, config, quant_config)
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states, input_metadata=input_metadata
)
# residual connection
hidden_states = attn_output + residual
residual = hidden_states
hidden_states = self.ln_2(hidden_states)
feed_forward_hidden_states = self.mlp(hidden_states)
# residual connection
hidden_states = residual + feed_forward_hidden_states
return hidden_states
class GPTBigCodeModel(nn.Module):
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
assert not config.add_cross_attention
self.embed_dim = config.hidden_size
lora_vocab = (
(lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
if lora_config
else 0
)
self.vocab_size = config.vocab_size + lora_vocab
self.wte = VocabParallelEmbedding(
self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
)
self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
self.h = nn.ModuleList(
[
GPTBigCodeBlock(i, config, cache_config, quant_config)
for i in range(config.num_hidden_layers)
]
)
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
def forward(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
hidden_states = inputs_embeds + position_embeds
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, input_metadata)
hidden_states = self.ln_f(hidden_states)
return hidden_states
class GPTBigCodeForCausalLM(nn.Module):
packed_modules_mapping = {"c_attn": ["c_attn"]}
supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"]
embedding_modules = {
"wte": "input_embeddings",
"lm_head": "output_embeddings",
}
embedding_padding_modules = []
def __init__(
self,
config: GPTBigCodeConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
lora_config: Optional[LoRAConfig] = None,
):
super().__init__()
self.config = config
self.lora_config = lora_config
self.quant_config = quant_config
self.transformer = GPTBigCodeModel(
config, cache_config, quant_config, lora_config
)
self.lm_head = self.transformer.wte
self.unpadded_vocab_size = config.vocab_size
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(config)
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
for name, loaded_weight in weights:
if "lm_head.weight" in name:
continue
if ".attn.bias" in name:
# Skip attention mask.
# NOTE: "c_attn.bias" should not be skipped.
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
# TODO (@robertgshaw2-neuralmagic): move to fp8 linear method
if "c_attn.input_scale" in name or "c_attn.weight_scale" in name:
weight_loader(param, loaded_weight, "q")
weight_loader(param, loaded_weight, "k")
weight_loader(param, loaded_weight, "v")
else:
weight_loader(param, loaded_weight)
EntryClass = GPTBigCodeForCausalLM
......@@ -157,6 +157,19 @@ def _set_global_server_args(server_args: ServerArgs):
}
def _set_torch_compile_config():
# The following configurations are for torch compile optimizations
import torch._dynamo.config
import torch._inductor.config
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True # Experimental feature to reduce compilation times, will be on by default in future
# FIXME: tmp workaround
torch._dynamo.config.accumulated_cache_size_limit = 128
def launch_server(
server_args: ServerArgs,
model_overide_args: Optional[dict] = None,
......@@ -190,6 +203,10 @@ def launch_server(
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
if server_args.enable_torch_compile:
_set_torch_compile_config()
_set_global_server_args(server_args)
# Allocate ports
......
......@@ -55,6 +55,7 @@ class ServerArgs:
disable_regex_jump_forward: bool = False
disable_cuda_graph: bool = False
disable_disk_cache: bool = False
enable_torch_compile: bool = False
attention_reduce_in_fp32: bool = False
enable_p2p_check: bool = False
efficient_weight_load: bool = False
......@@ -317,6 +318,11 @@ class ServerArgs:
action="store_true",
help="Disable disk cache to avoid possible crashes related to file system or high concurrency.",
)
parser.add_argument(
"--enable-torch-compile",
action="store_true",
help="Optimize the model with torch.compile, experimental feature.",
)
parser.add_argument(
"--attention-reduce-in-fp32",
action="store_true",
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment