Unverified Commit 70b68029 authored by Liangsheng Yin, committed by GitHub

Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

parent f3d32f88
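This change decouples sampling from the forward pass so that per-request sampling tensors (logit bias, penalties, regex vocab masks) no longer have to live inside the captured CUDA graph: `ModelRunner.forward` now returns only a `LogitsProcessorOutput`, and a new `ModelRunner.sample` applies the bias/penalty/mask post-processing and draws the next tokens after graph replay. A minimal sketch of the resulting call pattern, assuming a constructed `model_runner` and a prepared `batch` as in the diffs below:

```python
# Sketch of the new two-step API (names follow the diff below).
logits_output = model_runner.forward(batch)                 # pure model forward; CUDA-graph friendly
next_token_ids = model_runner.sample(logits_output, batch)  # bias, penalties, vocab mask, sampling
```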
@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    sample_output, logits_output = model_runner.forward(batch)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    sample_output, logits_output = model_runner.forward(batch)
-    next_token_ids = sample_output.batch_next_token_ids.tolist()
+    logits_output = model_runner.forward(batch)
+    next_token_ids = model_runner.sample(logits_output, batch).tolist()
     return next_token_ids, logits_output.next_token_logits
...
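A usage sketch of the updated test helpers above, assuming `reqs` and `model_runner` are already set up as elsewhere in this test file:

```python
# Prefill once, then decode a few steps with the new helper signatures.
next_token_ids, next_token_logits, batch = extend(reqs, model_runner)
for _ in range(4):
    next_token_ids, next_token_logits = decode(next_token_ids, batch, model_runner)
```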
@@ -35,21 +35,6 @@ class Sampler(CustomOp):
         self.forward_native = self.forward_cuda
         self.is_torch_compile = False

-    def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
-        # min-token, presence, frequency
-        if sampling_info.linear_penalties is not None:
-            logits += sampling_info.linear_penalties
-
-        # repetition
-        if sampling_info.scaling_penalties is not None:
-            logits = torch.where(
-                logits > 0,
-                logits / sampling_info.scaling_penalties,
-                logits * sampling_info.scaling_penalties,
-            )
-
-        return logits
-
     def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
         # Post process logits
         logits = logits.contiguous()
@@ -58,14 +43,6 @@ class Sampler(CustomOp):
             # FIXME: Temporary workaround for unknown bugs in torch.compile
             logits.add_(0)

-        if sampling_info.logit_bias is not None:
-            logits.add_(sampling_info.logit_bias)
-
-        if sampling_info.vocab_mask is not None:
-            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
-
-        logits = self._apply_penalties(logits, sampling_info)
-
         return torch.softmax(logits, dim=-1)

     def forward_cuda(
...
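The bias/penalty/mask logic removed from `Sampler` above is not deleted; it reappears in `ModelRunner._apply_logits_bias` further down. A self-contained illustration of the repetition-penalty transform it implements (toy values, not SGLang code):

```python
import torch

# A repetition penalty > 1 should make a token *less* likely regardless of sign,
# so positive logits are divided and negative logits are multiplied.
logits = torch.tensor([2.0, -1.0, 0.5])
scaling_penalties = torch.tensor([1.2, 1.2, 1.2])

penalized = torch.where(
    logits > 0,
    logits / scaling_penalties,  # shrink positive logits toward 0
    logits * scaling_penalties,  # push negative logits further below 0
)
print(penalized)  # tensor([ 1.6667, -1.2000,  0.4167])
```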
@@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs

-if TYPE_CHECKING:
-    from sglang.srt.layers.sampler import SampleOutput
-
 INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5

 # Put some global args for easy access
@@ -710,18 +706,3 @@ class ScheduleBatch:
         self.out_cache_loc = None
         self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)
-
-    def check_sample_results(self, sample_output: SampleOutput):
-        if not torch.all(sample_output.success):
-            probs = sample_output.probs
-            batch_next_token_ids = sample_output.batch_next_token_ids
-            logging.warning("Sampling failed, fallback to top_k=1 strategy")
-            probs = probs.masked_fill(torch.isnan(probs), 0.0)
-            argmax_ids = torch.argmax(probs, dim=-1)
-            batch_next_token_ids = torch.where(
-                sample_output.success, batch_next_token_ids, argmax_ids
-            )
-            sample_output.probs = probs
-            sample_output.batch_next_token_ids = batch_next_token_ids
-
-        return sample_output.batch_next_token_ids
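`check_sample_results` moves from `ScheduleBatch` to `ModelRunner._check_sample_results` (see below) with the same behavior. A standalone illustration of its top_k=1 fallback, using made-up tensors:

```python
import torch

success = torch.tensor([True, False])                       # second request failed to sample
probs = torch.tensor([[0.1, 0.9], [float("nan"), float("nan")]])
batch_next_token_ids = torch.tensor([1, 7])                 # 7 is garbage from the failed row

if not torch.all(success):
    probs = probs.masked_fill(torch.isnan(probs), 0.0)      # neutralize NaNs
    argmax_ids = torch.argmax(probs, dim=-1)                # greedy fallback ids
    batch_next_token_ids = torch.where(success, batch_next_token_ids, argmax_ids)

print(batch_next_token_ids)  # tensor([1, 0]): the failed row is replaced by its argmax
```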
@@ -547,8 +547,9 @@ class ModelTpServer:
         if self.model_runner.is_generation:
             # Forward and sample the next tokens
             if batch.extend_num_tokens != 0:
-                sample_output, logits_output = self.model_runner.forward(batch)
-                next_token_ids = batch.check_sample_results(sample_output)
+                logits_output = self.model_runner.forward(batch)
+                next_token_ids = self.model_runner.sample(logits_output, batch)
                 batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
                     next_token_ids
                 )
@@ -723,8 +724,8 @@ class ModelTpServer:
         batch.prepare_for_decode()

         # Forward and sample the next tokens
-        sample_output, logits_output = self.model_runner.forward(batch)
-        next_token_ids = batch.check_sample_results(sample_output)
+        logits_output = self.model_runner.forward(batch)
+        next_token_ids = self.model_runner.sample(logits_output, batch)
         batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
             next_token_ids
         )
...
@@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import (
     LogitsProcessor,
     LogitsProcessorOutput,
 )
-from sglang.srt.layers.sampler import SampleOutput
 from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
-from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.utils import monkey_patch_vllm_all_gather

 if TYPE_CHECKING:
@@ -129,10 +127,6 @@ class CudaGraphRunner:
             self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
         )

-        # Sampling info
-        vocab_size = model_runner.model_config.vocab_size
-        self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
-
         if self.use_torch_compile:
             set_torch_compile_config()
@@ -191,7 +185,6 @@ class CudaGraphRunner:
         def run_once():
             input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
-                sampling_info=self.sampling_info[:bs],
                 batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
@@ -250,14 +243,9 @@ class CudaGraphRunner:
             bs, self.req_pool_indices, self.seq_lens
         )

-        # Sampling inputs
-        self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
-
         # Replay
-        torch.cuda.synchronize()
         self.graphs[bs].replay()
-        torch.cuda.synchronize()
-        sample_output, logits_output = self.output_buffers[bs]
+        logits_output = self.output_buffers[bs]

         # Unpad
         if bs != raw_bs:
@@ -269,11 +257,6 @@ class CudaGraphRunner:
                 input_top_logprobs=None,
                 output_top_logprobs=None,
             )
-            sample_output = SampleOutput(
-                sample_output.success[:raw_bs],
-                sample_output.probs[:raw_bs],
-                sample_output.batch_next_token_ids[:raw_bs],
-            )

         # Extract logprobs
         if batch.return_logprob:
@@ -290,4 +273,4 @@ class CudaGraphRunner:
                 logits_output.next_token_logprobs, logits_metadata
             )[1]

-        return sample_output, logits_output
+        return logits_output
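The deleted `inplace_assign` dance existed because a captured CUDA graph replays against the exact tensor addresses recorded at capture time; every per-step input had to be copied into a pre-allocated buffer before replay. With sampling outside the graph, the vocab mask can simply be a fresh tensor each step. A minimal, self-contained sketch of the capture/replay constraint (plain PyTorch, requires a CUDA device; not SGLang code):

```python
import torch

# Static buffers whose addresses are baked into the graph at capture time.
static_in = torch.zeros(8, device="cuda")
static_out = torch.zeros(8, device="cuda")

# Warm up on a side stream before capture, as the PyTorch docs recommend.
s = torch.cuda.Stream()
s.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(s):
    static_out.copy_(static_in * 2)
torch.cuda.current_stream().wait_stream(s)

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out.copy_(static_in * 2)

# Replay only sees data written into the captured buffer; a brand-new tensor
# (like a freshly built vocab mask) would be invisible to the graph.
static_in.copy_(torch.arange(8.0, device="cuda"))
g.replay()
torch.cuda.synchronize()
print(static_out)  # tensor([ 0.,  2.,  4., ..., 14.], device='cuda:0')
```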
@@ -28,7 +28,6 @@ if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo


 class ForwardMode(IntEnum):
@@ -59,7 +58,6 @@ class InputMetadata:
     """Store all information of a forward pass."""

     forward_mode: ForwardMode
-    sampling_info: SamplingBatchInfo
     batch_size: int
     req_pool_indices: torch.Tensor
     seq_lens: torch.Tensor
@@ -170,7 +168,6 @@ class InputMetadata:
     ):
         ret = cls(
             forward_mode=batch.forward_mode,
-            sampling_info=batch.sampling_info,
             batch_size=batch.batch_size(),
             req_pool_indices=batch.req_pool_indices,
             seq_lens=batch.seq_lens,
@@ -182,8 +179,6 @@ class InputMetadata:
             top_logprobs_nums=batch.top_logprobs_nums,
         )

-        ret.sampling_info.update_penalties()
-        ret.sampling_info.update_regex_vocab_mask(batch)
         ret.compute_positions(batch)

         if not batch.forward_mode.is_decode():
...
@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.sampler import SampleOutput
+from sglang.srt.layers.sampler import SampleOutput, Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
 )
 from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     get_available_gpu_memory,
@@ -107,6 +108,7 @@ class ModelRunner:
         # Init components
         min_per_gpu_memory = self.init_torch_distributed()
+        self.sampler = Sampler()
         self.load_model()
         if server_args.lora_paths is not None:
             self.init_lora_manager()
@@ -466,11 +468,8 @@ class ModelRunner:
     def forward_decode(self, batch: ScheduleBatch):
         if self.server_args.lora_paths is not None:
             self.lora_manager.prepare_lora_batch(batch)

-        if (
-            self.cuda_graph_runner
-            and self.cuda_graph_runner.can_run(len(batch.reqs))
-            and batch.sampling_info.can_run_in_cuda_graph()
-        ):
+        if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
             return self.cuda_graph_runner.replay(batch)

         input_metadata = InputMetadata.from_schedule_batch(self, batch)
@@ -510,9 +509,7 @@ class ModelRunner:
             input_metadata.image_offsets,
         )

-    def forward(
-        self, batch: ScheduleBatch
-    ) -> Tuple[SampleOutput, LogitsProcessorOutput]:
+    def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
         assert batch.forward_mode is not None

         if self.is_multimodal_model and batch.forward_mode.is_extend():
@@ -524,6 +521,57 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {batch.forward_mode}")

+    def _check_sample_results(self, sample_output: SampleOutput):
+        if not torch.all(sample_output.success):
+            probs = sample_output.probs
+            batch_next_token_ids = sample_output.batch_next_token_ids
+            logging.warning("Sampling failed, fallback to top_k=1 strategy")
+            probs = probs.masked_fill(torch.isnan(probs), 0.0)
+            argmax_ids = torch.argmax(probs, dim=-1)
+            batch_next_token_ids = torch.where(
+                sample_output.success, batch_next_token_ids, argmax_ids
+            )
+            sample_output.probs = probs
+            sample_output.batch_next_token_ids = batch_next_token_ids
+
+        return sample_output.batch_next_token_ids
+
+    def _apply_logits_bias(
+        self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
+    ):
+        # Apply logit_bias
+        if sampling_info.logit_bias is not None:
+            logits.add_(sampling_info.logit_bias)
+
+        # min-token, presence, frequency
+        if sampling_info.linear_penalties is not None:
+            logits += sampling_info.linear_penalties
+
+        # repetition
+        if sampling_info.scaling_penalties is not None:
+            logits = torch.where(
+                logits > 0,
+                logits / sampling_info.scaling_penalties,
+                logits * sampling_info.scaling_penalties,
+            )
+
+        # Apply regex vocab_mask
+        if sampling_info.vocab_mask is not None:
+            logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
+
+        return logits
+
+    def sample(
+        self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
+    ) -> torch.Tensor:
+        batch.sampling_info.update_regex_vocab_mask(batch)
+        batch.sampling_info.update_penalties()
+        logits = self._apply_logits_bias(
+            logits_output.next_token_logits, batch.sampling_info
+        )
+        sample_output = self.sampler(logits, batch.sampling_info)
+        return self._check_sample_results(sample_output)


 @lru_cache()
 def import_model_classes():
...
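`ModelRunner.sample` now rebuilds the regex vocab mask and the penalties each step and applies them right before sampling, outside any captured graph. A toy end-to-end illustration of the masking-then-sampling step (made-up tensors, not the real `SamplingBatchInfo`):

```python
import torch

logits = torch.tensor([[1.0, 2.0, 0.5, -0.3, 0.1]])            # batch of 1, vocab of 5
vocab_mask = torch.tensor([[False, True, True, False, True]])  # True = banned by the regex FSM

logits = logits.masked_fill(vocab_mask, float("-inf"))  # banned tokens get zero probability
probs = torch.softmax(logits, dim=-1)
next_token = torch.multinomial(probs, num_samples=1)
print(next_token)  # always index 0 or 3, the only unmasked tokens
```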
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
         if self.config.tie_word_embeddings:
             self.lm_head.weight = self.model.embed_tokens.weight
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
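Every model file in the rest of the diff gets the same mechanical edit as `BaiChuanBaseForCausalLM` above: drop the `Sampler` import and member, and return the `LogitsProcessorOutput` straight from `forward`. A runnable toy skeleton of the resulting shape (illustrative names; the real stand-ins are each model's transformer and LM head):

```python
import torch
from torch import nn

class ToyForCausalLM(nn.Module):
    """Illustrative post-change pattern: forward returns logits only."""

    def __init__(self, vocab_size: int = 16, hidden_size: int = 8):
        super().__init__()
        self.model = nn.Embedding(vocab_size, hidden_size)  # stand-in for the transformer
        self.lm_head = nn.Linear(hidden_size, vocab_size, bias=False)

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        hidden_states = self.model(input_ids)
        # No per-model sampler call; ModelRunner.sample handles everything downstream.
        return self.lm_head(hidden_states)

logits = ToyForCausalLM()(torch.tensor([[1, 2, 3]]))
print(logits.shape)  # torch.Size([1, 3, 16])
```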
@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

 LoraConfig = None
@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
         self.transformer = ChatGLMModel(config, cache_config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
...
@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()
         self.model = CohereModel(config, quant_config)

     @torch.no_grad()
@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
             positions,
             input_metadata,
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
             padding_size=DEFAULT_VOCAB_PADDING_SIZE,
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = [
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     def forward(
         self,
@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
         self.transformer = ExaoneModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
         hidden_states = self.transformer(
             input_ids, positions, input_metadata, input_embeds
         )
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = GemmaModel(config, quant_config=quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return (sample_output, logits_output)

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
         self.quant_config = quant_config
         self.model = Gemma2Model(config, cache_config, quant_config)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_attention_sliding_window_size(self):
         return get_attention_sliding_window_size(self.config)
...
@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
         if lora_config:
             self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, input_metadata)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
...
@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         # Monkey patch _prepare_weights to load pre-sharded weights
         setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
         self.model = InternLM2Model(config, quant_config)
         self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

     @torch.no_grad()
     def forward(
@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.output.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
...
@@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import InputMetadata

@@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module):
         self.model = LlamaModel(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
-        self.sampler = Sampler()

         self.param_dict = dict(self.named_parameters())

@@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module):
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
-        logits_output = self.logits_processor(
+        return self.logits_processor(
             input_ids, hidden_states, self.lm_head.weight, input_metadata
         )
-        sample_output = self.sampler(logits_output, input_metadata.sampling_info)
-        return sample_output, logits_output

     def get_hidden_dim(self, module_name):
         if module_name in ["q_proj", "o_proj", "qkv_proj"]:
...