Unverified commit 36d5acfc, authored by Lianmin Zheng, committed by GitHub

Rename InputMetadata -> ForwardBatch (#1543)

parent 3f0fe08d
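The change is mechanical: in every model file below, the per-batch descriptor threaded through forward() is renamed from InputMetadata to ForwardBatch (still imported from sglang.srt.model_executor.forward_batch_info), and every call site such as RadixAttention and the logits processor receives the renamed object. A minimal sketch of the pattern, assuming the post-rename sglang package is installed; the toy layer and its names are illustrative only and are not part of this commit:

import torch
from torch import nn

# After this commit; previously: ... import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class ToyDecoderLayer(nn.Module):
    """Illustrative layer showing how the renamed batch object is threaded through."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,  # was: input_metadata: InputMetadata
    ) -> torch.Tensor:
        # The batch object is passed along unchanged; only the name differs.
        del positions, forward_batch  # unused by this toy layer
        return self.proj(hidden_states)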
@@ -46,7 +46,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_hip
 # ROCm: flashinfer available later
@@ -281,7 +281,7 @@ class DeepseekV2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
@@ -314,7 +314,7 @@ class DeepseekV2Attention(nn.Module):
         v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
             -1, self.num_local_heads * 256
         )
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, 256)[
             ..., : self.v_head_dim
         ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -433,7 +433,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -471,7 +471,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe
-        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -567,7 +567,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
@@ -579,7 +579,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -623,14 +623,14 @@ class DeepseekV2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.embed_tokens(input_ids)
         residual = None
         for i in range(len(self.layers)):
             layer = self.layers[i]
             hidden_states, residual = layer(
-                positions, hidden_states, input_metadata, residual
+                positions, hidden_states, forward_batch, residual
             )
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
@@ -658,11 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata)
+        hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class ExaoneGatedMLP(nn.Module):
@@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.out_proj(attn_output)
         return output
@@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -232,7 +232,7 @@ class ExaoneDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -270,7 +270,7 @@ class ExaoneModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -283,7 +283,7 @@ class ExaoneModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.ln_f(hidden_states, residual)
@@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
         hidden_states = self.transformer(
-            input_ids, positions, input_metadata, input_embeds
+            input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class GemmaMLP(nn.Module):
@@ -137,12 +137,12 @@ class GemmaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -192,7 +192,7 @@ class GemmaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -226,7 +226,7 @@ class GemmaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -243,7 +243,7 @@ class GemmaModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -175,12 +175,12 @@ class Gemma2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         if residual is None:
@@ -241,7 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = self.post_attention_layernorm(hidden_states)
@@ -286,7 +286,7 @@ class Gemma2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -302,7 +302,7 @@ class Gemma2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
+            input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
        )
     def get_attention_sliding_window_size(self):
...
@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class GPTBigCodeAttention(nn.Module):
@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.c_attn(hidden_states)
         q, k, v = qkv.split(
@@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module):
             ],
             dim=-1,
         )
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output, _ = self.c_proj(attn_output)
         return attn_output
@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
     def forward(
         self,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         residual = hidden_states
         hidden_states = self.ln_1(hidden_states)
         attn_output = self.attn(
-            hidden_states=hidden_states, input_metadata=input_metadata
+            hidden_states=hidden_states, forward_batch=forward_batch
         )
         # residual connection
         hidden_states = attn_output + residual
@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         position_ids: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         inputs_embeds = self.wte(input_ids)
         position_embeds = self.wpe(position_ids)
@@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module):
         for i in range(len(self.h)):
             layer = self.h[i]
-            hidden_states = layer(hidden_states, input_metadata)
+            hidden_states = layer(hidden_states, forward_batch)
         hidden_states = self.ln_f(hidden_states)
         return hidden_states
@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        hidden_states = self.transformer(input_ids, positions, input_metadata)
+        hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class Grok1MoE(nn.Module):
@@ -173,12 +173,12 @@ class Grok1Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         # Self Attention
         hidden_states = (
@@ -227,7 +227,7 @@ class Grok1DecoderLayer(nn.Module):
                 self.self_attn(
                     positions=positions,
                     hidden_states=self.pre_attn_norm(hidden_states),
-                    input_metadata=input_metadata,
+                    forward_batch=forward_batch,
                 )
             )
             + hidden_states
@@ -268,7 +268,7 @@ class Grok1Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -278,7 +278,7 @@ class Grok1Model(nn.Module):
             hidden_states = input_embeds
         for i in range(len(self.layers)):
-            hidden_states = self.layers[i](positions, hidden_states, input_metadata)
+            hidden_states = self.layers[i](positions, hidden_states, forward_batch)
         hidden_states = self.norm(hidden_states)
         hidden_states.mul_(self.config.output_multiplier_scale)
         return hidden_states
@@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class InternLM2MLP(nn.Module):
@@ -137,12 +137,12 @@ class InternLM2Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.wqkv(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.wo(attn_output)
         return output
@@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -194,7 +194,7 @@ class InternLMDecoderLayer(nn.Module):
         hidden_states = self.attention(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -229,7 +229,7 @@ class InternLM2Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -242,7 +242,7 @@ class InternLM2Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output.weight, input_metadata
+            input_ids, hidden_states, self.output.weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class LlamaMLP(nn.Module):
@@ -162,12 +162,12 @@ class LlamaAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
         q, k = self.rotary_emb(positions, q, k)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         # Fully Connected
@@ -270,7 +270,7 @@ class LlamaModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -283,7 +283,7 @@ class LlamaModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -310,12 +310,12 @@ class LlamaForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> LogitsProcessorOutput:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, input_metadata
+            input_ids, hidden_states, self.lm_head.weight, forward_batch
         )
     def get_hidden_dim(self, module_name):
...
@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         is_eos_token = input_ids == self.eos_token_id
         hidden_states = hidden_states[is_eos_token]
         scores = self.classification_head(hidden_states)
-        if scores.shape[0] != input_metadata.batch_size:
+        if scores.shape[0] != forward_batch.batch_size:
             print("Warning: the EOS tokens are missing in some sentences.")
             scores = torch.ones(
-                (input_metadata.batch_size, self.config.classification_out_size)
+                (forward_batch.batch_size, self.config.classification_out_size)
             ).to(input_ids.device)
         logits_output = LogitsProcessorOutput(
...
@@ -6,7 +6,7 @@ from transformers import LlamaConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
-from sglang.srt.model_executor.model_runner import InputMetadata
+from sglang.srt.model_executor.model_runner import ForwardBatch
 from sglang.srt.models.llama import LlamaModel
@@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
         assert (
             get_embedding
         ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
-        return self.pooler(hidden_states, input_metadata)
+        return self.pooler(hidden_states, forward_batch)
     def load_weights(
         self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
...
@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
@@ -51,13 +51,13 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> EmbeddingPoolerOutput:
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         scores = self.score(hidden_states)
-        return self.pooler(scores, input_metadata)
+        return self.pooler(scores, forward_batch)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         params_dict = dict(self.named_parameters())
@@ -102,19 +102,19 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
         get_embedding: bool = True,
     ) -> EmbeddingPoolerOutput:
         assert (
             get_embedding
         ), "LlamaForSequenceClassification is only used for embedding"
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         logits = self.score(hidden_states)
         weights = self.weights(hidden_states)
-        pooled_logits = self.pooler(logits, input_metadata).embeddings
+        pooled_logits = self.pooler(logits, forward_batch).embeddings
-        pooled_weights = self.pooler(weights, input_metadata).embeddings
+        pooled_weights = self.pooler(weights, forward_batch).embeddings
         rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
             -1, self.num_labels // 2
...
@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
     unpad_image,
     unpad_image_shape,
 )
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM
 from sglang.srt.models.mistral import MistralForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
@@ -130,12 +130,12 @@ class LlavaBaseForCausalLM(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = input_metadata.image_inputs
+        image_inputs = forward_batch.image_inputs
-        if input_metadata.forward_mode.is_extend():
+        if forward_batch.forward_mode.is_extend():
-            bs = input_metadata.batch_size
+            bs = forward_batch.batch_size
             # Got List[List[str]] extend it to List[str]
             # The length of the List should be equal to batch size
             modalities_list = []
@@ -151,7 +151,7 @@ class LlavaBaseForCausalLM(nn.Module):
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
-            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
             if need_vision.any():
@@ -348,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
                 image_features = new_image_features
                 # Fill in the placeholder for the image
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
+                prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
@@ -379,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
                     pt += 1
             return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode.is_decode():
+        elif forward_batch.forward_mode.is_decode():
-            return self.language_model(input_ids, positions, input_metadata)
+            return self.language_model(input_ids, positions, forward_batch)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # Load clip vision model by cfg['mm_vision_tower']:
...
@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.schedule_batch import ImageInputs
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.models.llama import LlamaForCausalLM
@@ -108,11 +108,11 @@ class LlavaVidForCausalLM(nn.Module):
         self,
         input_ids: torch.LongTensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
-        image_inputs = input_metadata.image_inputs
+        image_inputs = forward_batch.image_inputs
-        if input_metadata.forward_mode.is_extend():
+        if forward_batch.forward_mode.is_extend():
-            bs = input_metadata.batch_size
+            bs = forward_batch.batch_size
             # Embed text inputs
             input_embeds = self.language_model.model.embed_tokens(input_ids)
@@ -124,7 +124,7 @@ class LlavaVidForCausalLM(nn.Module):
                     max_image_offset.append(max(im.image_offsets))
                 else:
                     max_image_offset.append(-1)
-            start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
+            start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
             need_vision = start_positions <= np.array(max_image_offset)
             if need_vision.any():
@@ -169,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module):
                 image_features = new_image_features
                 # Fill in the placeholder for the image
-                extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
+                extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
-                prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
+                prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
                 pt = 0
                 for i in range(bs):
                     if not need_vision[i]:
@@ -200,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module):
                     pt += 1
             return self.language_model(
-                input_ids, positions, input_metadata, input_embeds=input_embeds
+                input_ids, positions, forward_batch, input_embeds=input_embeds
             )
-        elif input_metadata.forward_mode.is_decode():
+        elif forward_batch.forward_mode.is_decode():
-            return self.language_model(input_ids, positions, input_metadata)
+            return self.language_model(input_ids, positions, forward_batch)
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         # Load clip vision model by cfg['mm_vision_tower']:
...
@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 class MiniCPMMLP(nn.Module):
@@ -148,7 +148,7 @@ class MiniCPMAttention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         qkv, _ = self.qkv_proj(hidden_states)
         q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
@@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module):
         q, k = q.float(), k.float()
         q, k = self.rotary_emb(positions, q, k)
         q, k = q.to(orig_dtype), k.to(orig_dtype)
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         output, _ = self.o_proj(attn_output)
         return output
@@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states * (
             self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states = self.norm(hidden_states)
@@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is not None:
             input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head_weight = self.model.embed_tokens.weight
         else:
             lm_head_weight = self.lm_head.weight
         return self.logits_processor(
-            input_ids, hidden_states, lm_head_weight, input_metadata
+            input_ids, hidden_states, lm_head_weight, forward_batch
         )
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
...
@@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import InputMetadata
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import is_hip
 # ROCm: flashinfer available later
@@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         if self.q_lora_rank is not None:
             q = self.q_a_proj(hidden_states)[0]
@@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module):
         v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
             -1, self.num_local_heads * 128
         )
-        attn_output = self.attn(q, k, v, input_metadata)
+        attn_output = self.attn(q, k, v, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, 128)[
             ..., : self.v_head_dim
         ].reshape(-1, self.num_local_heads * self.v_head_dim)
@@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module):
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe
-        attn_output = self.attn(q_input, k_input, v_input, input_metadata)
+        attn_output = self.attn(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         # Self Attention
@@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module):
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
-            input_metadata=input_metadata,
+            forward_batch=forward_batch,
         )
         hidden_states = residual + hidden_states * (
             self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
@@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is None:
@@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module):
             hidden_states, residual = layer(
                 positions,
                 hidden_states,
-                input_metadata,
+                forward_batch,
                 residual,
             )
         hidden_states = self.norm(hidden_states)
@@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module):
         self,
         input_ids: torch.Tensor,
         positions: torch.Tensor,
-        input_metadata: InputMetadata,
+        forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
         if input_embeds is not None:
             input_embeds = input_embeds * self.config.scale_emb
-        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         hidden_states = hidden_states / self.scale_width
         if self.config.tie_word_embeddings:
             lm_head_weight = self.model.embed_tokens.weight
else: else:
lm_head_weight = self.lm_head.weight lm_head_weight = self.lm_head.weight
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata input_ids, hidden_states, lm_head_weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
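The MiniCPM3 MLA path above pads V out to a 128-wide head so the attention kernel sees one uniform head size, then slices the padded tail back off the output. A standalone sketch of that pad-and-slice step with toy shapes (the attention kernel itself is stubbed out; shapes and names here are assumptions, not SGLang code):

import torch
import torch.nn.functional as F

tokens, num_local_heads, v_head_dim, padded_dim = 10, 4, 96, 128
v = torch.randn(tokens, num_local_heads, v_head_dim)

# Pad the value head dim up to the kernel's head size and flatten the heads.
v_padded = F.pad(v, [0, padded_dim - v_head_dim], value=0).view(
    -1, num_local_heads * padded_dim
)

# Stand-in for the attention call; the real kernel also takes forward_batch.
attn_output = v_padded

# Drop the padded tail and restore the (tokens, heads * v_head_dim) layout.
attn_output = attn_output.view(-1, num_local_heads, padded_dim)[
    ..., :v_head_dim
].reshape(-1, num_local_heads * v_head_dim)
assert attn_output.shape == (tokens, num_local_heads * v_head_dim)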
...@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig ...@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MixtralMoE(nn.Module): class MixtralMoE(nn.Module):
...@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module): ...@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -270,7 +270,7 @@ class MixtralModel(nn.Module): ...@@ -270,7 +270,7 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -281,7 +281,7 @@ class MixtralModel(nn.Module): ...@@ -281,7 +281,7 @@ class MixtralModel(nn.Module):
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual positions, hidden_states, forward_batch, residual
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module): ...@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
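The Mixtral attention above follows the fused-QKV pattern used throughout these files: one projection, a split sized for grouped-query attention, the rotary embedding, then the attention call that now receives forward_batch. A torch-only sketch of the projection and split with toy sizes (rotary and the KV cache are omitted; every name here is illustrative):

import torch
import torch.nn as nn

hidden_size, head_dim = 64, 16
num_heads, num_kv_heads = 4, 2                 # grouped-query attention: fewer KV heads
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim

qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
hidden_states = torch.randn(10, hidden_size)

qkv = qkv_proj(hidden_states)
q, k, v = qkv.split([q_size, kv_size, kv_size], dim=-1)
# q: (10, 64); k, v: (10, 32). The rotary embedding and the attention kernel
# (which reads batch layout and cache indices from forward_batch) run next.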
...@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import ( ...@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MixtralMLP(nn.Module): class MixtralMLP(nn.Module):
...@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module): ...@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module): ...@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -303,7 +303,7 @@ class MixtralModel(nn.Module): ...@@ -303,7 +303,7 @@ class MixtralModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -314,7 +314,7 @@ class MixtralModel(nn.Module): ...@@ -314,7 +314,7 @@ class MixtralModel(nn.Module):
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual positions, hidden_states, forward_batch, residual
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module): ...@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
...@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm ...@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class OlmoeMoE(nn.Module): class OlmoeMoE(nn.Module):
...@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module): ...@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous()) q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module): ...@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module): ...@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module): ...@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module):
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual positions, hidden_states, forward_batch, residual
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module): ...@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
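One detail in the OLMoE attention above: Q and K are RMS-normalized before the rotary embedding is applied. A small sketch of that ordering, with a hand-rolled RMSNorm standing in for SGLang's layer (toy shapes and weights; the rotary step is left as a comment):

import torch

def rms_norm(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Plain RMSNorm over the last dimension; stand-in for sglang's RMSNorm layer.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * weight

tokens, q_size, kv_size = 10, 64, 32
q, k = torch.randn(tokens, q_size), torch.randn(tokens, kv_size)
q_weight, k_weight = torch.ones(q_size), torch.ones(kv_size)

# Normalize first, then rotate, matching the order in OlmoeAttention.forward above.
q, k = rms_norm(q.contiguous(), q_weight), rms_norm(k.contiguous(), k_weight)
# q, k = rotary_emb(positions, q, k)   # rotary application omitted in this sketch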
...@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import ( ...@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class QWenMLP(nn.Module): class QWenMLP(nn.Module):
...@@ -133,12 +133,12 @@ class QWenAttention(nn.Module): ...@@ -133,12 +133,12 @@ class QWenAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states) qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1) q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.c_proj(attn_output) output, _ = self.c_proj(attn_output)
return output return output
...@@ -177,7 +177,7 @@ class QWenBlock(nn.Module): ...@@ -177,7 +177,7 @@ class QWenBlock(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -185,7 +185,7 @@ class QWenBlock(nn.Module): ...@@ -185,7 +185,7 @@ class QWenBlock(nn.Module):
hidden_states = self.attn( hidden_states = self.attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -224,7 +224,7 @@ class QWenModel(nn.Module): ...@@ -224,7 +224,7 @@ class QWenModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.wte(input_ids) hidden_states = self.wte(input_ids)
for i in range(len(self.h)): for i in range(len(self.h)):
...@@ -232,7 +232,7 @@ class QWenModel(nn.Module): ...@@ -232,7 +232,7 @@ class QWenModel(nn.Module):
hidden_states = layer( hidden_states = layer(
positions, positions,
hidden_states, hidden_states,
input_metadata, forward_batch,
) )
hidden_states = self.ln_f(hidden_states) hidden_states = self.ln_f(hidden_states)
return hidden_states return hidden_states
...@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module): ...@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
): ):
hidden_states = self.transformer(input_ids, positions, input_metadata) hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
...@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor ...@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Qwen2Config = None Qwen2Config = None
...@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module): ...@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module): ...@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module): ...@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module): ...@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
input_metadata, forward_batch,
residual, residual,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
...@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
get_embedding: bool = False, get_embedding: bool = False,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding: if not get_embedding:
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
else: else:
return self.pooler(hidden_states, input_metadata) return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [ stacked_params_mapping = [
......
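Every hunk in this commit is the same mechanical substitution: the type InputMetadata becomes ForwardBatch and the parameter input_metadata becomes forward_batch. A hypothetical helper (not part of the PR, and not necessarily how the change was produced) that would apply the rename to a single model file:

import re
from pathlib import Path

def rename_input_metadata(path: Path) -> None:
    """Rewrite one file in place: InputMetadata -> ForwardBatch, input_metadata -> forward_batch."""
    text = path.read_text()
    text = re.sub(r"\bInputMetadata\b", "ForwardBatch", text)
    text = re.sub(r"\binput_metadata\b", "forward_batch", text)
    path.write_text(text)

# Example (path shown for illustration only):
# rename_input_metadata(Path("python/sglang/srt/models/qwen2.py"))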