Unverified commit 75ce37f4 authored by Liangsheng Yin, committed by GitHub

Move sampler into CUDA graph (#1201)


Co-authored-by: Yineng Zhang <me@zhyncs.com>
parent 97589a60
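Every model file touched by this diff receives the same change: the model now constructs a shared `Sampler` next to its `LogitsProcessor`, and `forward` returns `(sample_output, logits_output)` instead of the logits alone, so token sampling runs inside the captured CUDA graph rather than as a separate step after replay. A minimal sketch of that pattern, with the transformer backbone and LM head elided; the `DummyForCausalLM` name is illustrative and not part of the diff:

```python
import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata


class DummyForCausalLM(nn.Module):
    """Illustrative only: mirrors the edits applied to Mixtral, Qwen,
    Qwen2, Qwen2-MoE, and StableLM below."""

    def __init__(self, config, backbone: nn.Module, lm_head: nn.Module):
        super().__init__()
        self.model = backbone      # transformer stack, elided in this sketch
        self.lm_head = lm_head     # vocab projection, elided in this sketch
        self.logits_processor = LogitsProcessor(config)
        self.sampler = Sampler()   # new: sampler owned by the model

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ):
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        logits_output = self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, input_metadata
        )
        # New: sample inside the graph, driven by the batch's sampling info.
        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
        return sample_output, logits_output
```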
@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward(
self,
@@ -308,9 +310,11 @@
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -260,10 +262,11 @@
input_metadata: InputMetadata,
):
hidden_states = self.transformer(input_ids, positions, input_metadata)
next_tokens = self.logits_processor(
logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return next_tokens
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
Qwen2Config = None
@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@torch.no_grad()
@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
if not get_embedding:
return self.logits_processor(
logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
else:
return self.pooler(hidden_states, input_metadata)
......
@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def compute_logits(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
logits = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return logits
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
self.model = StableLMEpochModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
@@ -21,10 +21,63 @@ class SamplingBatchInfo:
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Bias Tensors
logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
def has_bias(self):
return (
self.logit_bias is not None
or self.vocab_mask is not None
or self.linear_penalties is not None
or self.scaling_penalties is not None
)
@classmethod
def dummy_one(cls, max_bs: int, vocab_size: int):
ret = cls(vocab_size=vocab_size)
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
return ret
def __getitem__(self, key):
if isinstance(key, slice):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
return SamplingBatchInfo(
vocab_size=self.vocab_size,
temperatures=self.temperatures[key],
top_ps=self.top_ps[key],
top_ks=self.top_ks[key],
min_ps=self.min_ps[key],
need_min_p_sampling=self.need_min_p_sampling,
)
else:
raise NotImplementedError
def inplace_assign(self, bs: int, other: SamplingBatchInfo):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
self.vocab_size = other.vocab_size
self.need_min_p_sampling = other.need_min_p_sampling
self.temperatures[:bs] = other.temperatures
self.top_ps[:bs] = other.top_ps
self.top_ks[:bs] = other.top_ks
self.min_ps[:bs] = other.min_ps
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
device = "cuda"
@@ -45,6 +98,7 @@ class SamplingBatchInfo:
ret.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
@@ -72,6 +126,25 @@ class SamplingBatchInfo:
return ret
def prepare_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
if penalizer.is_prepared():
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if penalizer.is_prepared():
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device="cuda",
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
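The prepare_penalties pass above collapses every prepared penalizer into at most two dense tensors: linear_penalties (additive adjustments, e.g. frequency/presence penalties) and scaling_penalties (the repetition penalizer's multiplicative factors), so the sampler can fold them into the logits with plain tensor ops that are safe to capture. The Sampler itself is not shown in this excerpt; the following is a hedged sketch of how such tensors are conventionally applied, with an illustrative function name that is not part of the diff:

```python
import torch


def apply_penalties_sketch(logits: torch.Tensor, sampling_info) -> torch.Tensor:
    """Illustrative sketch, not sglang's actual Sampler code."""
    if sampling_info.linear_penalties is not None:
        # Additive penalties (frequency/presence style) sum straight into the logits.
        logits = logits + sampling_info.linear_penalties
    if sampling_info.scaling_penalties is not None:
        # Repetition penalty in the common HF/vLLM convention: shrink positive
        # logits by dividing, push negative logits further down by multiplying.
        logits = torch.where(
            logits > 0,
            logits / sampling_info.scaling_penalties,
            logits * sampling_info.scaling_penalties,
        )
    return logits
```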
def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs
device = "cuda"
......
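The remaining SamplingBatchInfo additions exist so the CUDA graph runner can reuse one set of sampling tensors across capture and replay: dummy_one pre-allocates placeholder tensors for the largest batch size, __getitem__ hands the captured graph a fixed-size slice of them, inplace_assign copies a real batch's parameters into that slice before each replay, and has_bias guards against capturing batches that carry logit_bias/vocab_mask/penalty tensors. A hedged sketch of that flow; the module path and the surrounding graph-runner calls are assumptions, since neither appears in this excerpt:

```python
import torch

# Module path assumed; adjust to where SamplingBatchInfo lives in your checkout.
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo

MAX_BS, VOCAB_SIZE, bs = 8, 32000, 4

# Capture time: placeholder tensors sized for the largest batch, sliced down to
# the batch size being captured. The slice shares storage with the originals.
capture_info = SamplingBatchInfo.dummy_one(MAX_BS, VOCAB_SIZE)
run_info = capture_info[:bs]

graph = torch.cuda.CUDAGraph()
# with torch.cuda.graph(graph):
#     sample_output, logits_output = model.forward(..., metadata_with(run_info))

# Replay time: refresh the captured buffers in place so the recorded kernels
# read the current batch's temperatures/top_p/top_k/min_p, then replay.
# real_info = SamplingBatchInfo.from_schedule_batch(schedule_batch, VOCAB_SIZE)
# capture_info.inplace_assign(bs, real_info)  # batches with bias tensors skip the graph path
# graph.replay()
```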
@@ -180,7 +180,7 @@ class SRTRunner:
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.7,
mem_fraction_static=0.69,
trust_remote_code=False,
is_embedding=not self.is_generation,
)
......