Unverified Commit 70b68029 authored by Liangsheng Yin's avatar Liangsheng Yin Committed by GitHub
Browse files

Optimize conflicts between CUDA graph and vocab mask tensors (#1392)

parent f3d32f88
......@@ -207,15 +207,15 @@ def extend(reqs, model_runner):
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size)
sample_output, logits_output = model_runner.forward(batch)
next_token_ids = sample_output.batch_next_token_ids.tolist()
logits_output = model_runner.forward(batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits, batch
def decode(input_token_ids, batch, model_runner):
batch.prepare_for_decode(input_token_ids)
sample_output, logits_output = model_runner.forward(batch)
next_token_ids = sample_output.batch_next_token_ids.tolist()
logits_output = model_runner.forward(batch)
next_token_ids = model_runner.sample(logits_output, batch).tolist()
return next_token_ids, logits_output.next_token_logits
......
......@@ -35,21 +35,6 @@ class Sampler(CustomOp):
self.forward_native = self.forward_cuda
self.is_torch_compile = False
def _apply_penalties(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# min-token, presence, frequency
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties
# repetition
if sampling_info.scaling_penalties is not None:
logits = torch.where(
logits > 0,
logits / sampling_info.scaling_penalties,
logits * sampling_info.scaling_penalties,
)
return logits
def _get_probs(self, logits: torch.Tensor, sampling_info: SamplingBatchInfo):
# Post process logits
logits = logits.contiguous()
......@@ -58,14 +43,6 @@ class Sampler(CustomOp):
# FIXME: Temporary workaround for unknown bugs in torch.compile
logits.add_(0)
if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias)
if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
logits = self._apply_penalties(logits, sampling_info)
return torch.softmax(logits, dim=-1)
def forward_cuda(
......
......@@ -33,10 +33,6 @@ from sglang.srt.model_executor.forward_batch_info import ForwardMode
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
if TYPE_CHECKING:
from sglang.srt.layers.sampler import SampleOutput
INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
# Put some global args for easy access
......@@ -710,18 +706,3 @@ class ScheduleBatch:
self.out_cache_loc = None
self.top_logprobs_nums.extend(other.top_logprobs_nums)
self.return_logprob = any(req.return_logprob for req in self.reqs)
def check_sample_results(self, sample_output: SampleOutput):
if not torch.all(sample_output.success):
probs = sample_output.probs
batch_next_token_ids = sample_output.batch_next_token_ids
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
sample_output.success, batch_next_token_ids, argmax_ids
)
sample_output.probs = probs
sample_output.batch_next_token_ids = batch_next_token_ids
return sample_output.batch_next_token_ids
......@@ -547,8 +547,9 @@ class ModelTpServer:
if self.model_runner.is_generation:
# Forward and sample the next tokens
if batch.extend_num_tokens != 0:
sample_output, logits_output = self.model_runner.forward(batch)
next_token_ids = batch.check_sample_results(sample_output)
logits_output = self.model_runner.forward(batch)
next_token_ids = self.model_runner.sample(logits_output, batch)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
......@@ -723,8 +724,8 @@ class ModelTpServer:
batch.prepare_for_decode()
# Forward and sample the next tokens
sample_output, logits_output = self.model_runner.forward(batch)
next_token_ids = batch.check_sample_results(sample_output)
logits_output = self.model_runner.forward(batch)
next_token_ids = self.model_runner.sample(logits_output, batch)
batch.sampling_info.penalizer_orchestrator.cumulate_output_tokens(
next_token_ids
)
......
......@@ -30,10 +30,8 @@ from sglang.srt.layers.logits_processor import (
LogitsProcessor,
LogitsProcessorOutput,
)
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.utils import monkey_patch_vllm_all_gather
if TYPE_CHECKING:
......@@ -129,10 +127,6 @@ class CudaGraphRunner:
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
)
# Sampling info
vocab_size = model_runner.model_config.vocab_size
self.sampling_info = SamplingBatchInfo.dummy_one(self.max_bs, vocab_size)
if self.use_torch_compile:
set_torch_compile_config()
......@@ -191,7 +185,6 @@ class CudaGraphRunner:
def run_once():
input_metadata = InputMetadata(
forward_mode=ForwardMode.DECODE,
sampling_info=self.sampling_info[:bs],
batch_size=bs,
req_pool_indices=req_pool_indices,
seq_lens=seq_lens,
......@@ -250,14 +243,9 @@ class CudaGraphRunner:
bs, self.req_pool_indices, self.seq_lens
)
# Sampling inputs
self.sampling_info.inplace_assign(raw_bs, batch.sampling_info)
# Replay
torch.cuda.synchronize()
self.graphs[bs].replay()
torch.cuda.synchronize()
sample_output, logits_output = self.output_buffers[bs]
logits_output = self.output_buffers[bs]
# Unpad
if bs != raw_bs:
......@@ -269,11 +257,6 @@ class CudaGraphRunner:
input_top_logprobs=None,
output_top_logprobs=None,
)
sample_output = SampleOutput(
sample_output.success[:raw_bs],
sample_output.probs[:raw_bs],
sample_output.batch_next_token_ids[:raw_bs],
)
# Extract logprobs
if batch.return_logprob:
......@@ -290,4 +273,4 @@ class CudaGraphRunner:
logits_output.next_token_logprobs, logits_metadata
)[1]
return sample_output, logits_output
return logits_output
......@@ -28,7 +28,6 @@ if TYPE_CHECKING:
from sglang.srt.managers.schedule_batch import ScheduleBatch
from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
from sglang.srt.model_executor.model_runner import ModelRunner
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
class ForwardMode(IntEnum):
......@@ -59,7 +58,6 @@ class InputMetadata:
"""Store all inforamtion of a forward pass."""
forward_mode: ForwardMode
sampling_info: SamplingBatchInfo
batch_size: int
req_pool_indices: torch.Tensor
seq_lens: torch.Tensor
......@@ -170,7 +168,6 @@ class InputMetadata:
):
ret = cls(
forward_mode=batch.forward_mode,
sampling_info=batch.sampling_info,
batch_size=batch.batch_size(),
req_pool_indices=batch.req_pool_indices,
seq_lens=batch.seq_lens,
......@@ -182,8 +179,6 @@ class InputMetadata:
top_logprobs_nums=batch.top_logprobs_nums,
)
ret.sampling_info.update_penalties()
ret.sampling_info.update_regex_vocab_mask(batch)
ret.compute_positions(batch)
if not batch.forward_mode.is_decode():
......
......@@ -40,7 +40,7 @@ from vllm.model_executor.models import ModelRegistry
from sglang.srt.configs.model_config import AttentionArch, ModelConfig
from sglang.srt.layers.attention_backend import FlashInferAttnBackend, TritonAttnBackend
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.sampler import SampleOutput
from sglang.srt.layers.sampler import SampleOutput, Sampler
from sglang.srt.lora.lora_manager import LoRAManager
from sglang.srt.managers.schedule_batch import ScheduleBatch, global_server_args_dict
from sglang.srt.mem_cache.memory_pool import (
......@@ -49,6 +49,7 @@ from sglang.srt.mem_cache.memory_pool import (
ReqToTokenPool,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
get_available_gpu_memory,
......@@ -107,6 +108,7 @@ class ModelRunner:
# Init componnets
min_per_gpu_memory = self.init_torch_distributed()
self.sampler = Sampler()
self.load_model()
if server_args.lora_paths is not None:
self.init_lora_manager()
......@@ -466,11 +468,8 @@ class ModelRunner:
def forward_decode(self, batch: ScheduleBatch):
if self.server_args.lora_paths is not None:
self.lora_manager.prepare_lora_batch(batch)
if (
self.cuda_graph_runner
and self.cuda_graph_runner.can_run(len(batch.reqs))
and batch.sampling_info.can_run_in_cuda_graph()
):
if self.cuda_graph_runner and self.cuda_graph_runner.can_run(len(batch.reqs)):
return self.cuda_graph_runner.replay(batch)
input_metadata = InputMetadata.from_schedule_batch(self, batch)
......@@ -510,9 +509,7 @@ class ModelRunner:
input_metadata.image_offsets,
)
def forward(
self, batch: ScheduleBatch
) -> Tuple[SampleOutput, LogitsProcessorOutput]:
def forward(self, batch: ScheduleBatch) -> Tuple[LogitsProcessorOutput]:
assert batch.forward_mode is not None
if self.is_multimodal_model and batch.forward_mode.is_extend():
......@@ -524,6 +521,57 @@ class ModelRunner:
else:
raise ValueError(f"Invaid forward mode: {batch.forward_mode}")
def _check_sample_results(self, sample_output: SampleOutput):
if not torch.all(sample_output.success):
probs = sample_output.probs
batch_next_token_ids = sample_output.batch_next_token_ids
logging.warning("Sampling failed, fallback to top_k=1 strategy")
probs = probs.masked_fill(torch.isnan(probs), 0.0)
argmax_ids = torch.argmax(probs, dim=-1)
batch_next_token_ids = torch.where(
sample_output.success, batch_next_token_ids, argmax_ids
)
sample_output.probs = probs
sample_output.batch_next_token_ids = batch_next_token_ids
return sample_output.batch_next_token_ids
def _apply_logits_bias(
self, logits: torch.Tensor, sampling_info: SamplingBatchInfo
):
# Apply logit_bias
if sampling_info.logit_bias is not None:
logits.add_(sampling_info.logit_bias)
# min-token, presence, frequency
if sampling_info.linear_penalties is not None:
logits += sampling_info.linear_penalties
# repetition
if sampling_info.scaling_penalties is not None:
logits = torch.where(
logits > 0,
logits / sampling_info.scaling_penalties,
logits * sampling_info.scaling_penalties,
)
# Apply regex vocab_mask
if sampling_info.vocab_mask is not None:
logits = logits.masked_fill(sampling_info.vocab_mask, float("-inf"))
return logits
def sample(
self, logits_output: LogitsProcessorOutput, batch: ScheduleBatch
) -> torch.Tensor:
batch.sampling_info.update_regex_vocab_mask(batch)
batch.sampling_info.update_penalties()
logits = self._apply_logits_bias(
logits_output.next_token_logits, batch.sampling_info
)
sample_output = self.sampler(logits, batch.sampling_info)
return self._check_sample_results(sample_output)
@lru_cache()
def import_model_classes():
......
......@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -346,7 +345,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
if self.config.tie_word_embeddings:
self.lm_head.weight = self.model.embed_tokens.weight
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward(
self,
......@@ -355,12 +353,9 @@ class BaiChuanBaseForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -42,7 +42,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
LoraConfig = None
......@@ -371,7 +370,6 @@ class ChatGLMForCausalLM(nn.Module):
self.transformer = ChatGLMModel(config, cache_config, quant_config)
self.lm_head = self.transformer.output_layer
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -381,11 +379,9 @@ class ChatGLMForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
......
......@@ -64,7 +64,6 @@ from vllm.model_executor.utils import set_weight_attrs
from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -327,7 +326,6 @@ class CohereForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.model = CohereModel(config, quant_config)
@torch.no_grad()
......@@ -342,11 +340,9 @@ class CohereForCausalLM(nn.Module):
positions,
input_metadata,
)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -45,7 +45,6 @@ from vllm.transformers_utils.configs.dbrx import DbrxConfig
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -383,7 +382,6 @@ class DbrxForCausalLM(nn.Module):
padding_size=DEFAULT_VOCAB_PADDING_SIZE,
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -393,11 +391,9 @@ class DbrxForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = [
......
......@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -386,7 +385,6 @@ class DeepseekForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -396,11 +394,9 @@ class DeepseekForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -46,7 +46,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -649,7 +648,6 @@ class DeepseekV2ForCausalLM(nn.Module):
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
def forward(
self,
......@@ -658,11 +656,9 @@ class DeepseekV2ForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -304,7 +303,6 @@ class ExaoneForCausalLM(nn.Module):
self.transformer = ExaoneModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -317,11 +315,9 @@ class ExaoneForCausalLM(nn.Module):
hidden_states = self.transformer(
input_ids, positions, input_metadata, input_embeds
)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -288,7 +287,6 @@ class GemmaForCausalLM(nn.Module):
self.quant_config = quant_config
self.model = GemmaModel(config, quant_config=quant_config)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -299,11 +297,9 @@ class GemmaForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return (sample_output, logits_output)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -37,7 +37,6 @@ from sglang.srt.layers.activation import GeluAndMul
from sglang.srt.layers.layernorm import GemmaRMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -347,7 +346,6 @@ class Gemma2ForCausalLM(nn.Module):
self.quant_config = quant_config
self.model = Gemma2Model(config, cache_config, quant_config)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -358,11 +356,9 @@ class Gemma2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
......
......@@ -35,7 +35,6 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -262,7 +261,6 @@ class GPTBigCodeForCausalLM(nn.Module):
if lora_config:
self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -272,11 +270,9 @@ class GPTBigCodeForCausalLM(nn.Module):
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters(remove_duplicate=False))
......
......@@ -46,7 +46,6 @@ from sglang.srt.layers.fused_moe import FusedMoE
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -298,7 +297,6 @@ class Grok1ForCausalLM(nn.Module):
self.model = Grok1Model(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
# Monkey patch _prepare_weights to load pre-sharded weights
setattr(DefaultModelLoader, "_prepare_weights", _prepare_presharded_weights)
......@@ -315,11 +313,9 @@ class Grok1ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -40,7 +40,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -263,7 +262,6 @@ class InternLM2ForCausalLM(nn.Module):
self.model = InternLM2Model(config, quant_config)
self.output = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
@torch.no_grad()
def forward(
......@@ -274,11 +272,9 @@ class InternLM2ForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.output.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......
......@@ -41,7 +41,6 @@ from sglang.srt.layers.activation import SiluAndMul
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.sampler import Sampler
from sglang.srt.layers.torchao_utils import torchao_quantize_param_data
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
......@@ -305,7 +304,6 @@ class LlamaForCausalLM(nn.Module):
self.model = LlamaModel(config, quant_config=quant_config)
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
self.param_dict = dict(self.named_parameters())
......@@ -318,11 +316,9 @@ class LlamaForCausalLM(nn.Module):
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
logits_output = self.logits_processor(
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
sample_output = self.sampler(logits_output, input_metadata.sampling_info)
return sample_output, logits_output
def get_hidden_dim(self, module_name):
if module_name in ["q_proj", "o_proj", "qkv_proj"]:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment