Commit 36d5acfc (unverified), authored by Lianmin Zheng, committed by GitHub
Browse files

Rename InputMetadata -> ForwardBatch (#1543)

parent 3f0fe08d
...@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig ...@@ -49,7 +49,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.torchao_utils import apply_torchao_config_ from sglang.srt.layers.torchao_utils import apply_torchao_config_
from sglang.srt.managers.schedule_batch import global_server_args_dict from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class Qwen2MoeMLP(nn.Module): class Qwen2MoeMLP(nn.Module):
...@@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module): ...@@ -221,12 +221,12 @@ class Qwen2MoeAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -281,7 +281,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module): ...@@ -293,7 +293,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -331,7 +331,7 @@ class Qwen2MoeModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module): ...@@ -342,7 +342,7 @@ class Qwen2MoeModel(nn.Module):
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual positions, hidden_states, forward_batch, residual
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module): ...@@ -373,12 +373,12 @@ class Qwen2MoeForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
...@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import ( ...@@ -40,7 +40,7 @@ from sglang.srt.layers.linear import (
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class StablelmMLP(nn.Module): class StablelmMLP(nn.Module):
...@@ -145,12 +145,12 @@ class StablelmAttention(nn.Module): ...@@ -145,12 +145,12 @@ class StablelmAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module): ...@@ -173,7 +173,7 @@ class StablelmDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
residual = hidden_states residual = hidden_states
...@@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module): ...@@ -181,7 +181,7 @@ class StablelmDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
hidden_states = residual + hidden_states hidden_states = residual + hidden_states
...@@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module): ...@@ -218,7 +218,7 @@ class StableLMEpochModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module): ...@@ -230,7 +230,7 @@ class StableLMEpochModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
input_metadata, forward_batch,
) )
hidden_states = self.norm(hidden_states) hidden_states = self.norm(hidden_states)
return hidden_states return hidden_states
...@@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module): ...@@ -255,12 +255,12 @@ class StableLmForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
...@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -41,7 +41,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.model_runner import InputMetadata from sglang.srt.model_executor.model_runner import ForwardBatch
class XverseMLP(nn.Module): class XverseMLP(nn.Module):
...@@ -160,12 +160,12 @@ class XverseAttention(nn.Module): ...@@ -160,12 +160,12 @@ class XverseAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -222,7 +222,7 @@ class XverseDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
# Self Attention # Self Attention
...@@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -234,7 +234,7 @@ class XverseDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -271,7 +271,7 @@ class XverseModel(nn.Module): ...@@ -271,7 +271,7 @@ class XverseModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
if input_embeds is None: if input_embeds is None:
...@@ -284,7 +284,7 @@ class XverseModel(nn.Module): ...@@ -284,7 +284,7 @@ class XverseModel(nn.Module):
hidden_states, residual = layer( hidden_states, residual = layer(
positions, positions,
hidden_states, hidden_states,
input_metadata, forward_batch,
residual, residual,
) )
# print(f"layer[{i}].hidden_states: {hidden_states}") # print(f"layer[{i}].hidden_states: {hidden_states}")
...@@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module): ...@@ -312,12 +312,12 @@ class XverseForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None, input_embeds: torch.Tensor = None,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds) hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights( def load_weights(
......
...@@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader ...@@ -44,7 +44,7 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.model_executor.forward_batch_info import InputMetadata from sglang.srt.model_executor.forward_batch_info import ForwardBatch
class XverseMLP(nn.Module): class XverseMLP(nn.Module):
...@@ -244,12 +244,12 @@ class XverseAttention(nn.Module): ...@@ -244,12 +244,12 @@ class XverseAttention(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states) qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, input_metadata) attn_output = self.attn(q, k, v, forward_batch)
output, _ = self.o_proj(attn_output) output, _ = self.o_proj(attn_output)
return output return output
...@@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -300,7 +300,7 @@ class XverseDecoderLayer(nn.Module):
self, self,
positions: torch.Tensor, positions: torch.Tensor,
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
residual: Optional[torch.Tensor], residual: Optional[torch.Tensor],
) -> torch.Tensor: ) -> torch.Tensor:
# Self Attention # Self Attention
...@@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module): ...@@ -312,7 +312,7 @@ class XverseDecoderLayer(nn.Module):
hidden_states = self.self_attn( hidden_states = self.self_attn(
positions=positions, positions=positions,
hidden_states=hidden_states, hidden_states=hidden_states,
input_metadata=input_metadata, forward_batch=forward_batch,
) )
# Fully Connected # Fully Connected
...@@ -353,14 +353,14 @@ class XverseModel(nn.Module): ...@@ -353,14 +353,14 @@ class XverseModel(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.embed_tokens(input_ids) hidden_states = self.embed_tokens(input_ids)
residual = None residual = None
for i in range(len(self.layers)): for i in range(len(self.layers)):
layer = self.layers[i] layer = self.layers[i]
hidden_states, residual = layer( hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual positions, hidden_states, forward_batch, residual
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states return hidden_states
...@@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module): ...@@ -388,11 +388,11 @@ class XverseMoeForCausalLM(nn.Module):
self, self,
input_ids: torch.Tensor, input_ids: torch.Tensor,
positions: torch.Tensor, positions: torch.Tensor,
input_metadata: InputMetadata, forward_batch: ForwardBatch,
) -> torch.Tensor: ) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata) hidden_states = self.model(input_ids, positions, forward_batch)
return self.logits_processor( return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata input_ids, hidden_states, self.lm_head.weight, forward_batch
) )
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
......
Markdown is supported
0% uploaded. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment