Unverified Commit 75ce37f4 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Move sampler into CUDA graph (#1201)


Co-authored-by: default avatarYineng Zhang <me@zhyncs.com>
parent 97589a60
...@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,6 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module): ...@@ -299,6 +300,7 @@ class MixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config, prefix="model") self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward( def forward(
self, self,
...@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module): ...@@ -308,9 +310,11 @@ class MixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -45,6 +45,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module): ...@@ -333,6 +334,7 @@ class QuantMixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config) self.model = MixtralModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module): ...@@ -343,9 +345,11 @@ class QuantMixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul ...@@ -39,6 +39,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module): ...@@ -251,6 +252,7 @@ class QWenLMHeadModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64 vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module): ...@@ -260,10 +262,11 @@ class QWenLMHeadModel(nn.Module):
input_metadata: InputMetadata, input_metadata: InputMetadata,
): ):
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, input_metadata)
next_tokens = self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
return next_tokens sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -38,8 +38,9 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
Qwen2Config = None Qwen2Config = None
...@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -276,6 +277,7 @@ class Qwen2ForCausalLM(nn.Module):
self.model = Qwen2Model(config, quant_config=quant_config) self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@torch.no_grad() @torch.no_grad()
...@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -289,9 +291,11 @@ class Qwen2ForCausalLM(nn.Module):
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
if not get_embedding: if not get_embedding:
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
else: else:
return self.pooler(hidden_states, input_metadata) return self.pooler(hidden_states, input_metadata)
......
...@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import ( ...@@ -35,10 +35,8 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear, ReplicatedLinear,
RowParallelLinear, RowParallelLinear,
) )
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import ( from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, ParallelLMHead,
VocabParallelEmbedding, VocabParallelEmbedding,
...@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul ...@@ -49,6 +47,7 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -366,6 +365,7 @@ class Qwen2MoeForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config config.vocab_size, config.hidden_size, quant_config=quant_config
) )
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -376,20 +376,11 @@ class Qwen2MoeForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
def compute_logits( return sample_output, logits_output
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
logits = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -40,6 +40,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import InputMetadata
...@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module): ...@@ -249,6 +250,7 @@ class StableLmForCausalLM(nn.Module):
self.model = StableLMEpochModel(config, quant_config=quant_config) self.model = StableLMEpochModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config) self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad() @torch.no_grad()
def forward( def forward(
...@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module): ...@@ -259,9 +261,11 @@ class StableLmForCausalLM(nn.Module):
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor( logits_output = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, input_metadata
) )
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
...@@ -21,10 +21,63 @@ class SamplingBatchInfo: ...@@ -21,10 +21,63 @@ class SamplingBatchInfo:
top_ps: torch.Tensor = None top_ps: torch.Tensor = None
top_ks: torch.Tensor = None top_ks: torch.Tensor = None
min_ps: torch.Tensor = None min_ps: torch.Tensor = None
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Bias Tensors
logit_bias: torch.Tensor = None logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None vocab_mask: torch.Tensor = None
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
def has_bias(self):
return (
self.logit_bias is not None
or self.vocab_mask is not None
or self.linear_penalties is not None
or self.scaling_penalties is not None
)
@classmethod
def dummy_one(cls, max_bs: int, vocab_size: int):
ret = cls(vocab_size=vocab_size)
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
return ret
def __getitem__(self, key):
if isinstance(key, slice):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
return SamplingBatchInfo(
vocab_size=self.vocab_size,
temperatures=self.temperatures[key],
top_ps=self.top_ps[key],
top_ks=self.top_ks[key],
min_ps=self.min_ps[key],
need_min_p_sampling=self.need_min_p_sampling,
)
else:
raise NotImplementedError
def inplace_assign(self, bs: int, other: SamplingBatchInfo):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
self.vocab_size = other.vocab_size
self.need_min_p_sampling = other.need_min_p_sampling
self.temperatures[:bs] = other.temperatures
self.top_ps[:bs] = other.top_ps
self.top_ks[:bs] = other.top_ks
self.min_ps[:bs] = other.min_ps
@classmethod @classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int): def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
device = "cuda" device = "cuda"
...@@ -45,6 +98,7 @@ class SamplingBatchInfo: ...@@ -45,6 +98,7 @@ class SamplingBatchInfo:
ret.min_ps = torch.tensor( ret.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device [r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
) )
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizers will do nothing if they evaluate themselves as not required by looking at # Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this # the sampling_params of the requests (See {_is_required()} of each penalizers). So this
...@@ -72,6 +126,25 @@ class SamplingBatchInfo: ...@@ -72,6 +126,25 @@ class SamplingBatchInfo:
return ret return ret
def prepare_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
if penalizer.is_prepared():
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if penalizer.is_prepared():
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device="cuda",
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch): def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs bs, reqs = batch.batch_size(), batch.reqs
device = "cuda" device = "cuda"
......
...@@ -180,7 +180,7 @@ class SRTRunner: ...@@ -180,7 +180,7 @@ class SRTRunner:
tp_size=tp_size, tp_size=tp_size,
dtype=get_dtype_str(torch_dtype), dtype=get_dtype_str(torch_dtype),
port=port, port=port,
mem_fraction_static=0.7, mem_fraction_static=0.69,
trust_remote_code=False, trust_remote_code=False,
is_embedding=not self.is_generation, is_embedding=not self.is_generation,
) )
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment