Unverified Commit f25f4dfd authored by Yineng Zhang's avatar Yineng Zhang Committed by GitHub
Browse files

hotfix: revert sampler CUDA Graph (#1242)

parent 184ae1c6
......@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.output.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -39,9 +39,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.logits_processor import LogitProcessorOutput, LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -303,7 +302,6 @@ class LlamaForCausalLM(nn.Module):
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -312,13 +310,11 @@ class LlamaForCausalLM(nn.Module):
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
) -> LogitProcessorOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_module_name(self, name):
stacked_params_mapping = [
......
......@@ -24,7 +24,7 @@ from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.models.llama2 import LlamaModel
......@@ -65,7 +65,7 @@ class LlamaForClassification(nn.Module):
(input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device)
return LogitsProcessorOutput(
return LogitProcessorOutput(
next_token_logits=scores,
next_token_logprobs=scores,
normalized_prompt_logprobs=scores,
......
......@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -298,7 +297,6 @@ class MiniCPMForCausalLM(nn.Module):
self.scale_width = self.config.hidden_size / self.config.dim_model_base
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -316,11 +314,9 @@ class MiniCPMForCausalLM(nn.Module):
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -41,7 +41,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -300,7 +299,6 @@ class MixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config, prefix="model")
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward(
self,
......@@ -310,11 +308,9 @@ class MixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -45,7 +45,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -334,7 +333,6 @@ class QuantMixtralForCausalLM(nn.Module):
self.model = MixtralModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -345,11 +343,9 @@ class QuantMixtralForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -39,7 +39,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -252,7 +251,6 @@ class QWenLMHeadModel(nn.Module):
vocab_size = ((config.vocab_size + 63) // 64) * 64
self.lm_head = ParallelLMHead(vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -262,11 +260,10 @@ class QWenLMHeadModel(nn.Module):
input_metadata: InputMetadata,
):
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
next_tokens = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
return next_tokens
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -38,9 +38,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
Qwen2Config = None
......@@ -277,7 +276,6 @@ class Qwen2ForCausalLM(nn.Module):
self.model = Qwen2Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
@torch.no_grad()
......@@ -291,11 +289,9 @@ class Qwen2ForCausalLM(nn.Module):
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
if not get_embedding:
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
else:
return self.pooler(hidden_states, input_metadata)
......
......@@ -35,8 +35,10 @@ from vllm.model_executor.layers.linear import (
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
......@@ -47,7 +49,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -365,7 +366,6 @@ class Qwen2MoeForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -376,11 +376,20 @@ class Qwen2MoeForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def compute_logits(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
logits = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return logits
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -40,7 +40,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -250,7 +249,6 @@ class StableLmForCausalLM(nn.Module):
self.model = StableLMEpochModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -261,11 +259,9 @@ class StableLmForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -21,63 +21,10 @@ class SamplingBatchInfo:
top_ps: torch.Tensor = None
top_ks: torch.Tensor = None
min_ps: torch.Tensor = None
# Dispatch in CUDA graph
need_min_p_sampling: bool = False
# Bias Tensors
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
logit_bias: torch.Tensor = None
vocab_mask: torch.Tensor = None
# Penalizer
penalizer_orchestrator: penaltylib.BatchedPenalizerOrchestrator = None
linear_penalties: torch.Tensor = None
scaling_penalties: torch.Tensor = None
def has_bias(self):
return (
self.logit_bias is not None
or self.vocab_mask is not None
or self.linear_penalties is not None
or self.scaling_penalties is not None
)
@classmethod
def dummy_one(cls, max_bs: int, vocab_size: int):
ret = cls(vocab_size=vocab_size)
ret.temperatures = torch.ones((max_bs, 1), dtype=torch.float, device="cuda")
ret.top_ps = torch.ones((max_bs,), dtype=torch.float, device="cuda")
ret.top_ks = torch.ones((max_bs,), dtype=torch.int, device="cuda")
ret.min_ps = torch.zeros((max_bs,), dtype=torch.float, device="cuda")
return ret
def __getitem__(self, key):
if isinstance(key, slice):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
return SamplingBatchInfo(
vocab_size=self.vocab_size,
temperatures=self.temperatures[key],
top_ps=self.top_ps[key],
top_ks=self.top_ks[key],
min_ps=self.min_ps[key],
need_min_p_sampling=self.need_min_p_sampling,
)
else:
raise NotImplementedError
def inplace_assign(self, bs: int, other: SamplingBatchInfo):
# NOTE: We do not use cuda graph when there is bias tensors
assert not self.has_bias()
self.vocab_size = other.vocab_size
self.need_min_p_sampling = other.need_min_p_sampling
self.temperatures[:bs] = other.temperatures
self.top_ps[:bs] = other.top_ps
self.top_ks[:bs] = other.top_ks
self.min_ps[:bs] = other.min_ps
@classmethod
def from_schedule_batch(cls, batch: ScheduleBatch, vocab_size: int):
device = "cuda"
......@@ -98,7 +45,6 @@ class SamplingBatchInfo:
ret.min_ps = torch.tensor(
[r.sampling_params.min_p for r in reqs], dtype=torch.float, device=device
)
ret.need_min_p_sampling = any(r.sampling_params.min_p > 0 for r in reqs)
# Each penalizers will do nothing if they evaluate themselves as not required by looking at
# the sampling_params of the requests (See {_is_required()} of each penalizers). So this
......@@ -126,25 +72,6 @@ class SamplingBatchInfo:
return ret
def prepare_penalties(self):
self.scaling_penalties = None
self.linear_penalties = None
for penalizer in self.penalizer_orchestrator.penalizers.values():
if isinstance(penalizer, penaltylib.BatchedRepetitionPenalizer):
if penalizer.is_prepared():
self.scaling_penalties = penalizer.cumulated_repetition_penalties
else:
if penalizer.is_prepared():
if self.linear_penalties is None:
bs = self.penalizer_orchestrator.batch.batch_size()
self.linear_penalties = torch.zeros(
(bs, self.vocab_size),
dtype=torch.float32,
device="cuda",
)
self.linear_penalties = penalizer.apply(self.linear_penalties)
def update_regex_vocab_mask(self, batch: ScheduleBatch):
bs, reqs = batch.batch_size(), batch.reqs
device = "cuda"
......
......@@ -180,7 +180,7 @@ class SRTRunner:
tp_size=tp_size,
dtype=get_dtype_str(torch_dtype),
port=port,
mem_fraction_static=0.69,
mem_fraction_static=0.7,
trust_remote_code=False,
is_embedding=not self.is_generation,
)
......
__version__ = "0.2.14"
__version__ = "0.2.14.post1"
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment