Unverified Commit 36d5acfc authored by Lianmin Zheng, committed by GitHub

Rename InputMetadata -> ForwardBatch (#1543)

parent 3f0fe08d
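
For readers skimming the diff: the change is a mechanical rename of the per-batch metadata argument (InputMetadata -> ForwardBatch) threaded through every model's forward pass. Below is a minimal, hypothetical sketch of the pattern; MyAttention is an illustrative stand-in, not code from this repository — only the import path and the renamed parameter mirror the actual change.

import torch
from torch import nn

# Old import (removed in this commit):
#     from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch


class MyAttention(nn.Module):
    """Hypothetical attention module showing the renamed forward-pass argument."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.o_proj = nn.Linear(hidden_size, hidden_size)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,  # was: input_metadata: InputMetadata
    ) -> torch.Tensor:
        qkv = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(3, dim=-1)
        # In the real models the batch object is handed to RadixAttention:
        #     attn_output = self.attn(q, k, v, forward_batch)  # was: input_metadata
        # This sketch skips the attention kernel and only projects, since the
        # point is the signature, not the math.
        return self.o_proj(q + k + v)

The same rename propagates unchanged through the decoder layers, the *Model classes, the *ForCausalLM wrappers, and the logits-processor and pooler calls, as the hunks below show.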
......@@ -46,7 +46,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip
# ROCm: flashinfer available later
......@@ -281,7 +281,7 @@ class DeepseekV2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
......@@ -314,7 +314,7 @@ class DeepseekV2Attention(nn.Module):
v = torch.nn.functional.pad(v, [0, 256 - self.v_head_dim], value=0).view(
-1, self.num_local_heads * 256
)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, 256)[
..., : self.v_head_dim
].reshape(-1, self.num_local_heads * self.v_head_dim)
......@@ -433,7 +433,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
......@@ -471,7 +471,7 @@ class DeepseekV2AttentionMLA(nn.Module):
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
attn_output = self.attn(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fn:
......@@ -567,7 +567,7 @@ class DeepseekV2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -579,7 +579,7 @@ class DeepseekV2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -623,14 +623,14 @@ class DeepseekV2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids)
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -658,11 +658,11 @@ class DeepseekV2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class ExaoneGatedMLP(nn.Module):
......@@ -162,12 +162,12 @@ class ExaoneAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.out_proj(attn_output)
return output
......@@ -220,7 +220,7 @@ class ExaoneDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -232,7 +232,7 @@ class ExaoneDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -270,7 +270,7 @@ class ExaoneModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -283,7 +283,7 @@ class ExaoneModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.ln_f(hidden_states, residual)
......@@ -309,14 +309,14 @@ class ExaoneForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
hidden_states = self.transformer(
input_ids, positions, input_metadata, input_embeds
input_ids, positions, forward_batch, input_embeds
)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class GemmaMLP(nn.Module):
......@@ -137,12 +137,12 @@ class GemmaAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -180,7 +180,7 @@ class GemmaDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -192,7 +192,7 @@ class GemmaDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -226,7 +226,7 @@ class GemmaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -243,7 +243,7 @@ class GemmaModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -293,12 +293,12 @@ class GemmaForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -37,7 +37,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
# Aligned with HF's implementation, using sliding window inclusive with the last token
......@@ -175,12 +175,12 @@ class Gemma2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -230,7 +230,7 @@ class Gemma2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
if residual is None:
......@@ -241,7 +241,7 @@ class Gemma2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = self.post_attention_layernorm(hidden_states)
......@@ -286,7 +286,7 @@ class Gemma2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -302,7 +302,7 @@ class Gemma2Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -352,12 +352,12 @@ class Gemma2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
input_ids, hidden_states, self.model.embed_tokens.weight, forward_batch
)
def get_attention_sliding_window_size(self):
......
......@@ -35,7 +35,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class GPTBigCodeAttention(nn.Module):
......@@ -90,7 +90,7 @@ class GPTBigCodeAttention(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.split(
......@@ -101,7 +101,7 @@ class GPTBigCodeAttention(nn.Module):
],
dim=-1,
)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
attn_output, _ = self.c_proj(attn_output)
return attn_output
......@@ -160,12 +160,12 @@ class GPTBigCodeBlock(nn.Module):
def forward(
self,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
residual = hidden_states
hidden_states = self.ln_1(hidden_states)
attn_output = self.attn(
hidden_states=hidden_states, input_metadata=input_metadata
hidden_states=hidden_states, forward_batch=forward_batch
)
# residual connection
hidden_states = attn_output + residual
......@@ -214,7 +214,7 @@ class GPTBigCodeModel(nn.Module):
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
inputs_embeds = self.wte(input_ids)
position_embeds = self.wpe(position_ids)
......@@ -222,7 +222,7 @@ class GPTBigCodeModel(nn.Module):
for i in range(len(self.h)):
layer = self.h[i]
hidden_states = layer(hidden_states, input_metadata)
hidden_states = layer(hidden_states, forward_batch)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -267,11 +267,11 @@ class GPTBigCodeForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -46,7 +46,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class Grok1MoE(nn.Module):
......@@ -173,12 +173,12 @@ class Grok1Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -219,7 +219,7 @@ class Grok1DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Self Attention
hidden_states = (
......@@ -227,7 +227,7 @@ class Grok1DecoderLayer(nn.Module):
self.self_attn(
positions=positions,
hidden_states=self.pre_attn_norm(hidden_states),
input_metadata=input_metadata,
forward_batch=forward_batch,
)
)
+ hidden_states
......@@ -268,7 +268,7 @@ class Grok1Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -278,7 +278,7 @@ class Grok1Model(nn.Module):
hidden_states = input_embeds
for i in range(len(self.layers)):
hidden_states = self.layers[i](positions, hidden_states, input_metadata)
hidden_states = self.layers[i](positions, hidden_states, forward_batch)
hidden_states = self.norm(hidden_states)
hidden_states.mul_(self.config.output_multiplier_scale)
return hidden_states
......@@ -309,12 +309,12 @@ class Grok1ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class InternLM2MLP(nn.Module):
......@@ -137,12 +137,12 @@ class InternLM2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.wqkv(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.wo(attn_output)
return output
......@@ -182,7 +182,7 @@ class InternLMDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -194,7 +194,7 @@ class InternLMDecoderLayer(nn.Module):
hidden_states = self.attention(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -229,7 +229,7 @@ class InternLM2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -242,7 +242,7 @@ class InternLM2Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -268,12 +268,12 @@ class InternLM2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.output.weight, input_metadata
input_ids, hidden_states, self.output.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class LlamaMLP(nn.Module):
......@@ -162,12 +162,12 @@ class LlamaAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -221,7 +221,7 @@ class LlamaDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -233,7 +233,7 @@ class LlamaDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -270,7 +270,7 @@ class LlamaModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -283,7 +283,7 @@ class LlamaModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -310,12 +310,12 @@ class LlamaForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> LogitsProcessorOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def get_hidden_dim(self, module_name):
......
......@@ -23,7 +23,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
......@@ -50,18 +50,18 @@ class LlamaForClassification(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
is_eos_token = input_ids == self.eos_token_id
hidden_states = hidden_states[is_eos_token]
scores = self.classification_head(hidden_states)
if scores.shape[0] != input_metadata.batch_size:
if scores.shape[0] != forward_batch.batch_size:
print("Warning: the EOS tokens are missing in some sentences.")
scores = torch.ones(
(input_metadata.batch_size, self.config.classification_out_size)
(forward_batch.batch_size, self.config.classification_out_size)
).to(input_ids.device)
logits_output = LogitsProcessorOutput(
......
......@@ -6,7 +6,7 @@ from transformers import LlamaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.model_executor.model_runner import ForwardBatch
from sglang.srt.models.llama import LlamaModel
......@@ -26,15 +26,15 @@ class LlamaEmbeddingModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaEmbeddingModel / MistralModel is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.pooler(hidden_states, input_metadata)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.pooler(hidden_states, forward_batch)
def load_weights(
self, weights: Iterable[Tuple[str, torch.Tensor]], name=None, loaded_weight=None
......
......@@ -24,7 +24,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
......@@ -51,13 +51,13 @@ class LlamaForSequenceClassification(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> EmbeddingPoolerOutput:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
scores = self.score(hidden_states)
return self.pooler(scores, input_metadata)
return self.pooler(scores, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
params_dict = dict(self.named_parameters())
......@@ -102,19 +102,19 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = True,
) -> EmbeddingPoolerOutput:
assert (
get_embedding
), "LlamaForSequenceClassification is only used for embedding"
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
logits = self.score(hidden_states)
weights = self.weights(hidden_states)
pooled_logits = self.pooler(logits, input_metadata).embeddings
pooled_weights = self.pooler(weights, input_metadata).embeddings
pooled_logits = self.pooler(logits, forward_batch).embeddings
pooled_weights = self.pooler(weights, forward_batch).embeddings
rews = pooled_logits.view(-1, self.num_labels // 2, 2)[:, :, 0].view(
-1, self.num_labels // 2
......
......@@ -41,7 +41,7 @@ from sglang.srt.mm_utils import (
unpad_image,
unpad_image_shape,
)
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM
from sglang.srt.models.mistral import MistralForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
......@@ -130,12 +130,12 @@ class LlavaBaseForCausalLM(nn.Module):
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
image_inputs = input_metadata.image_inputs
image_inputs = forward_batch.image_inputs
if input_metadata.forward_mode.is_extend():
bs = input_metadata.batch_size
if forward_batch.forward_mode.is_extend():
bs = forward_batch.batch_size
# Got List[List[str]] extend it to List[str]
# The length of the List should be equal to batch size
modalities_list = []
......@@ -151,7 +151,7 @@ class LlavaBaseForCausalLM(nn.Module):
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
......@@ -348,8 +348,8 @@ class LlavaBaseForCausalLM(nn.Module):
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
......@@ -379,10 +379,10 @@ class LlavaBaseForCausalLM(nn.Module):
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
input_ids, positions, forward_batch, input_embeds=input_embeds
)
elif input_metadata.forward_mode.is_decode():
return self.language_model(input_ids, positions, input_metadata)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Load clip vision model by cfg['mm_vision_tower']:
......
......@@ -27,7 +27,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.schedule_batch import ImageInputs
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.models.llama import LlamaForCausalLM
......@@ -108,11 +108,11 @@ class LlavaVidForCausalLM(nn.Module):
self,
input_ids: torch.LongTensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
image_inputs = input_metadata.image_inputs
if input_metadata.forward_mode.is_extend():
bs = input_metadata.batch_size
image_inputs = forward_batch.image_inputs
if forward_batch.forward_mode.is_extend():
bs = forward_batch.batch_size
# Embed text inputs
input_embeds = self.language_model.model.embed_tokens(input_ids)
......@@ -124,7 +124,7 @@ class LlavaVidForCausalLM(nn.Module):
max_image_offset.append(max(im.image_offsets))
else:
max_image_offset.append(-1)
start_positions = positions[input_metadata.extend_start_loc].cpu().numpy()
start_positions = positions[forward_batch.extend_start_loc].cpu().numpy()
need_vision = start_positions <= np.array(max_image_offset)
if need_vision.any():
......@@ -169,8 +169,8 @@ class LlavaVidForCausalLM(nn.Module):
image_features = new_image_features
# Fill in the placeholder for the image
extend_start_loc_cpu = input_metadata.extend_start_loc.cpu().numpy()
prefix_lens_cpu = input_metadata.extend_prefix_lens.cpu().numpy()
extend_start_loc_cpu = forward_batch.extend_start_loc.cpu().numpy()
prefix_lens_cpu = forward_batch.extend_prefix_lens.cpu().numpy()
pt = 0
for i in range(bs):
if not need_vision[i]:
......@@ -200,10 +200,10 @@ class LlavaVidForCausalLM(nn.Module):
pt += 1
return self.language_model(
input_ids, positions, input_metadata, input_embeds=input_embeds
input_ids, positions, forward_batch, input_embeds=input_embeds
)
elif input_metadata.forward_mode.is_decode():
return self.language_model(input_ids, positions, input_metadata)
elif forward_batch.forward_mode.is_decode():
return self.language_model(input_ids, positions, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
# Load clip vision model by cfg['mm_vision_tower']:
......
......@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MiniCPMMLP(nn.Module):
......@@ -148,7 +148,7 @@ class MiniCPMAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -156,7 +156,7 @@ class MiniCPMAttention(nn.Module):
q, k = q.float(), k.float()
q, k = self.rotary_emb(positions, q, k)
q, k = q.to(orig_dtype), k.to(orig_dtype)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -199,7 +199,7 @@ class MiniCPMDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -208,7 +208,7 @@ class MiniCPMDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states * (
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
......@@ -252,7 +252,7 @@ class MiniCPMModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -266,7 +266,7 @@ class MiniCPMModel(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states = self.norm(hidden_states)
......@@ -303,19 +303,19 @@ class MiniCPMForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is not None:
input_embeds = input_embeds * self.config.scale_emb
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata
input_ids, hidden_states, lm_head_weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -42,7 +42,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.utils import is_hip
# ROCm: flashinfer available later
......@@ -193,7 +193,7 @@ class MiniCPM3Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
if self.q_lora_rank is not None:
q = self.q_a_proj(hidden_states)[0]
......@@ -230,7 +230,7 @@ class MiniCPM3Attention(nn.Module):
v = torch.nn.functional.pad(v, [0, 128 - self.v_head_dim], value=0).view(
-1, self.num_local_heads * 128
)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, 128)[
..., : self.v_head_dim
].reshape(-1, self.num_local_heads * self.v_head_dim)
......@@ -341,7 +341,7 @@ class MiniCPM3AttentionMLA(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
q_len = hidden_states.shape[0]
q_input = hidden_states.new_empty(
......@@ -383,7 +383,7 @@ class MiniCPM3AttentionMLA(nn.Module):
q_input[..., self.kv_lora_rank :] = q_pe
k_input[..., self.kv_lora_rank :] = k_pe
attn_output = self.attn(q_input, k_input, v_input, input_metadata)
attn_output = self.attn(q_input, k_input, v_input, forward_batch)
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
if self.w_vc.dtype == torch.float8_e4m3fn:
......@@ -472,7 +472,7 @@ class MiniCPM3DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -481,7 +481,7 @@ class MiniCPM3DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states * (
self.config.scale_depth / math.sqrt(self.config.num_hidden_layers)
......@@ -528,7 +528,7 @@ class MiniCPM3Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -542,7 +542,7 @@ class MiniCPM3Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states = self.norm(hidden_states)
......@@ -581,19 +581,19 @@ class MiniCPM3ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is not None:
input_embeds = input_embeds * self.config.scale_emb
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
hidden_states = hidden_states / self.scale_width
if self.config.tie_word_embeddings:
lm_head_weight = self.model.embed_tokens.weight
else:
lm_head_weight = self.lm_head.weight
return self.logits_processor(
input_ids, hidden_states, lm_head_weight, input_metadata
input_ids, hidden_states, lm_head_weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -43,7 +43,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MixtralMoE(nn.Module):
......@@ -171,12 +171,12 @@ class MixtralAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -220,7 +220,7 @@ class MixtralDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -232,7 +232,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -270,7 +270,7 @@ class MixtralModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -281,7 +281,7 @@ class MixtralModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -307,12 +307,12 @@ class MixtralForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -45,7 +45,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class MixtralMLP(nn.Module):
......@@ -216,12 +216,12 @@ class MixtralAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -256,7 +256,7 @@ class MixtralDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -268,7 +268,7 @@ class MixtralDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -303,7 +303,7 @@ class MixtralModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -314,7 +314,7 @@ class MixtralModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -339,12 +339,12 @@ class QuantMixtralForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -48,7 +48,7 @@ from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class OlmoeMoE(nn.Module):
......@@ -175,13 +175,13 @@ class OlmoeAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.q_norm(q.contiguous()), self.k_norm(k.contiguous())
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -225,7 +225,7 @@ class OlmoeDecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> torch.Tensor:
# Self Attention
......@@ -238,7 +238,7 @@ class OlmoeDecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -274,7 +274,7 @@ class OlmoeModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -285,7 +285,7 @@ class OlmoeModel(nn.Module):
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
positions, hidden_states, forward_batch, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -314,12 +314,12 @@ class OlmoeForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -39,7 +39,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class QWenMLP(nn.Module):
......@@ -133,12 +133,12 @@ class QWenAttention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.c_attn(hidden_states)
q, k, v = qkv.chunk(chunks=3, dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.c_proj(attn_output)
return output
......@@ -177,7 +177,7 @@ class QWenBlock(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
# Self Attention
residual = hidden_states
......@@ -185,7 +185,7 @@ class QWenBlock(nn.Module):
hidden_states = self.attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
hidden_states = residual + hidden_states
......@@ -224,7 +224,7 @@ class QWenModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
hidden_states = self.wte(input_ids)
for i in range(len(self.h)):
......@@ -232,7 +232,7 @@ class QWenModel(nn.Module):
hidden_states = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
)
hidden_states = self.ln_f(hidden_states)
return hidden_states
......@@ -257,11 +257,11 @@ class QWenLMHeadModel(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
):
hidden_states = self.transformer(input_ids, positions, input_metadata)
hidden_states = self.transformer(input_ids, positions, forward_batch)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
......@@ -40,7 +40,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
Qwen2Config = None
......@@ -149,12 +149,12 @@ class Qwen2Attention(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata)
attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output)
return output
......@@ -196,7 +196,7 @@ class Qwen2DecoderLayer(nn.Module):
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention
......@@ -208,7 +208,7 @@ class Qwen2DecoderLayer(nn.Module):
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata,
forward_batch=forward_batch,
)
# Fully Connected
......@@ -243,7 +243,7 @@ class Qwen2Model(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
......@@ -256,7 +256,7 @@ class Qwen2Model(nn.Module):
hidden_states, residual = layer(
positions,
hidden_states,
input_metadata,
forward_batch,
residual,
)
hidden_states, _ = self.norm(hidden_states, residual)
......@@ -283,17 +283,17 @@ class Qwen2ForCausalLM(nn.Module):
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
get_embedding: bool = False,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
if not get_embedding:
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
input_ids, hidden_states, self.lm_head.weight, forward_batch
)
else:
return self.pooler(hidden_states, input_metadata)
return self.pooler(hidden_states, forward_batch)
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
stacked_params_mapping = [
......