Commit 65501a9c authored by Lianmin Zheng

Fix commandr import; format code

parent db611066
@@ -510,9 +510,13 @@ class ModelRpcServer:
         batch.prepare_for_decode()

         # Forward
-        logits, (_, _, decode_top_logprobs, _, last_logprobs) = (
-            self.model_runner.forward(batch, ForwardMode.DECODE)
-        )
+        logits, (
+            _,
+            _,
+            decode_top_logprobs,
+            _,
+            last_logprobs,
+        ) = self.model_runner.forward(batch, ForwardMode.DECODE)
         next_token_ids, _ = batch.sample(logits)
         next_token_ids = next_token_ids.cpu().tolist()
...
@@ -24,27 +24,30 @@ from typing import List, Optional, Tuple
 import torch
 import torch.utils.checkpoint
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.managers.router.model_runner import InputMetadata
 from torch import nn
 from torch.nn.parameter import Parameter
-from transformers import CohereConfig
-from vllm.model_executor.parallel_utils.parallel_state import (get_tensor_model_parallel_rank,
-                                                               get_tensor_model_parallel_world_size)
+from transformers import PretrainedConfig
 from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.linear import (LinearMethodBase,
-                                               MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
+from vllm.model_executor.layers.linear import (
+    LinearMethodBase,
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from vllm.model_executor.parallel_utils.parallel_state import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from vllm.model_executor.utils import set_weight_attrs
-from vllm.model_executor.weight_utils import (default_weight_loader,
-                                              hf_model_weights_iterator)
-
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.managers.router.model_runner import InputMetadata
+from vllm.model_executor.weight_utils import (
+    default_weight_loader,
+    hf_model_weights_iterator,
+)


 @torch.compile
@@ -53,14 +56,12 @@ def layer_norm_func(hidden_states, weight, variance_epsilon):
     hidden_states = hidden_states.to(torch.float32)
     mean = hidden_states.mean(-1, keepdim=True)
     variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
-    hidden_states = (hidden_states - mean) * torch.rsqrt(variance +
-                                                         variance_epsilon)
+    hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
     hidden_states = weight.to(torch.float32) * hidden_states
     return hidden_states.to(input_dtype)


 class LayerNorm(nn.Module):
-
     def __init__(self, param_shape=None, eps=1e-5):
         super().__init__()
         self.weight = nn.Parameter(torch.ones(param_shape))
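For reference, a minimal standalone check of what the reformatted layer_norm_func computes: a bias-free LayerNorm done in float32 and cast back to the input dtype. The comparison against torch.nn.functional.layer_norm and the input_dtype assignment are additions to make the sketch self-contained (the @torch.compile decorator is omitted here); they are not part of this diff.

    import torch
    import torch.nn.functional as F

    def layer_norm_func(hidden_states, weight, variance_epsilon):
        # Bias-free LayerNorm computed in float32, then cast back to the input dtype.
        input_dtype = hidden_states.dtype  # assumed: defined just above this hunk
        hidden_states = hidden_states.to(torch.float32)
        mean = hidden_states.mean(-1, keepdim=True)
        variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
        hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
        hidden_states = weight.to(torch.float32) * hidden_states
        return hidden_states.to(input_dtype)

    x = torch.randn(2, 4, 8, dtype=torch.float16)
    w = torch.ones(8)
    out = layer_norm_func(x, w, 1e-5)
    ref = F.layer_norm(x.float(), (8,), weight=w, eps=1e-5).to(torch.float16)
    assert torch.allclose(out, ref, atol=1e-2)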
@@ -68,8 +69,9 @@ class LayerNorm(nn.Module):
         set_weight_attrs(self.weight, {"weight_loader": self.weight_loader})

     def forward(self, hidden_states, residuals=None):
-        hidden_states = layer_norm_func(hidden_states, self.weight,
-                                        self.variance_epsilon)
+        hidden_states = layer_norm_func(
+            hidden_states, self.weight, self.variance_epsilon
+        )
         return hidden_states, residuals

     def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
@@ -79,15 +81,13 @@ class LayerNorm(nn.Module):
         if shard_dim is not None:
             shard_size = param_data.shape[shard_dim]
             start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(shard_dim, start_idx,
-                                                 shard_size)
+            loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size)
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)


 # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere
 class CohereMLP(nn.Module):
-
     def __init__(
         self,
         config,
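A toy sketch of the tensor-parallel narrowing that LayerNorm.weight_loader performs above: each rank keeps only its slice of the loaded weight along the shard dimension. The rank, world size, and shapes below are made up for illustration.

    import torch

    tp_rank, tp_size = 1, 4                         # hypothetical tensor-parallel rank / world size
    full_weight = torch.arange(32.0).reshape(8, 4)  # e.g. (num_heads, head_dim) from the checkpoint
    shard_dim = 0                                   # heads are split across ranks
    shard_size = full_weight.shape[shard_dim] // tp_size
    start_idx = tp_rank * shard_size
    local_weight = full_weight.narrow(shard_dim, start_idx, shard_size)
    print(local_weight.shape)  # torch.Size([2, 4]) -- this rank's slice of the per-head weights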
@@ -119,10 +119,9 @@ class CohereMLP(nn.Module):


 class CohereAttention(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         layer_id: int = 0,
         linear_method: Optional[LinearMethodBase] = None,
     ):
@@ -148,8 +147,8 @@ class CohereAttention(nn.Module):
         self.kv_size = self.num_kv_heads * self.head_dim
         self.scaling = self.head_dim**-0.5
-        self.max_position_embeddings = getattr(
-            config, "model_max_length", None) or getattr(
-                config, "max_position_embeddings", 8192)
+        self.max_position_embeddings = getattr(
+            config, "model_max_length", None
+        ) or getattr(config, "max_position_embeddings", 8192)
         self.rope_theta = config.rope_theta
         self.rope_scaling = getattr(config, "rope_scaling", None)
         self.use_qk_norm = getattr(config, "use_qk_norm", False)
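The max_position_embeddings lookup above keeps its original semantics: because it uses `or`, a config whose model_max_length is missing, None, or 0 falls through to max_position_embeddings, defaulting to 8192. A tiny illustration with a made-up config object:

    class _Cfg:  # hypothetical stand-in for the HF config, not part of this diff
        model_max_length = None
        max_position_embeddings = 4096

    cfg = _Cfg()
    max_pos = getattr(cfg, "model_max_length", None) or getattr(
        cfg, "max_position_embeddings", 8192
    )
    print(max_pos)  # 4096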
@@ -183,12 +182,13 @@ class CohereAttention(nn.Module):
             layer_id=layer_id,
         )
         if self.use_qk_norm:
-            self.q_norm = LayerNorm(param_shape=(self.num_heads,
-                                                 self.head_dim),
-                                    eps=config.layer_norm_eps)
-            self.k_norm = LayerNorm(param_shape=(self.num_kv_heads,
-                                                 self.head_dim),
-                                    eps=config.layer_norm_eps)
+            self.q_norm = LayerNorm(
+                param_shape=(self.num_heads, self.head_dim), eps=config.layer_norm_eps
+            )
+            self.k_norm = LayerNorm(
+                param_shape=(self.num_kv_heads, self.head_dim),
+                eps=config.layer_norm_eps,
+            )

     def _apply_qk_norm(self, q, k):
         q = q.view(*q.shape[:-1], -1, self.head_dim)
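A rough standalone sketch of the per-head QK-norm pattern set up above: queries are reshaped to (..., num_heads, head_dim), normalized over each head, then flattened back. It uses torch.nn.functional.layer_norm in place of the module's own LayerNorm, and the shapes are illustrative, not taken from this diff.

    import torch
    import torch.nn.functional as F

    num_heads, head_dim = 8, 64
    q = torch.randn(2, 16, num_heads * head_dim)  # (batch, seq, num_heads * head_dim)

    q = q.view(*q.shape[:-1], -1, head_dim)       # (..., num_heads, head_dim), as in _apply_qk_norm
    q = F.layer_norm(q, (head_dim,), eps=1e-5)    # normalize each head independently
    q = q.flatten(-2)                             # back to (..., num_heads * head_dim)
    print(q.shape)  # torch.Size([2, 16, 512])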
@@ -216,19 +216,23 @@ class CohereAttention(nn.Module):


 class CohereDecoderLayer(nn.Module):
-
-    def __init__(self,
-                 config: CohereConfig,
-                 layer_id: int = 0,
-                 linear_method: Optional[LinearMethodBase] = None):
+    def __init__(
+        self,
+        config: PretrainedConfig,
+        layer_id: int = 0,
+        linear_method: Optional[LinearMethodBase] = None,
+    ):
         super().__init__()
         self.hidden_size = config.hidden_size
-        self.self_attn = CohereAttention(config, layer_id=layer_id, linear_method=linear_method)
+        self.self_attn = CohereAttention(
+            config, layer_id=layer_id, linear_method=linear_method
+        )
         self.mlp = CohereMLP(config, linear_method=linear_method)
-        self.input_layernorm = LayerNorm(param_shape=(config.hidden_size),
-                                         eps=config.layer_norm_eps)
+        self.input_layernorm = LayerNorm(
+            param_shape=(config.hidden_size), eps=config.layer_norm_eps
+        )

     def forward(
         self,
@@ -253,23 +257,26 @@ class CohereDecoderLayer(nn.Module):


 class CohereModel(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ):
         super().__init__()
         self.config = config
         self.vocab_size = config.vocab_size
-        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
-                                                   config.hidden_size)
-        self.layers = nn.ModuleList([
-            CohereDecoderLayer(config, i, linear_method=linear_method)
-            for i in range(config.num_hidden_layers)
-        ])
-        self.norm = LayerNorm(param_shape=(config.hidden_size),
-                              eps=config.layer_norm_eps)
+        self.embed_tokens = VocabParallelEmbedding(
+            config.vocab_size, config.hidden_size
+        )
+        self.layers = nn.ModuleList(
+            [
+                CohereDecoderLayer(config, i, linear_method=linear_method)
+                for i in range(config.num_hidden_layers)
+            ]
+        )
+        self.norm = LayerNorm(
+            param_shape=(config.hidden_size), eps=config.layer_norm_eps
+        )

     def forward(
         self,
@@ -292,10 +299,9 @@ class CohereModel(nn.Module):


 class CohereForCausalLM(nn.Module):
-
     def __init__(
         self,
-        config: CohereConfig,
+        config: PretrainedConfig,
         linear_method: Optional[LinearMethodBase] = None,
     ) -> None:
         super().__init__()
@@ -311,7 +317,11 @@ class CohereForCausalLM(nn.Module):
         positions: torch.Tensor,
         input_metadata: InputMetadata,
     ) -> torch.Tensor:
-        hidden_states = self.model(input_ids, positions, input_metadata,)
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            input_metadata,
+        )
         return self.logits_processor(
             input_ids, hidden_states, self.model.embed_tokens.weight, input_metadata
         )
@@ -334,7 +344,8 @@ class CohereForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params = set()
         for name, loaded_weight in hf_model_weights_iterator(
-            model_name_or_path, cache_dir, load_format, revision):
+            model_name_or_path, cache_dir, load_format, revision
+        ):
             for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
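For context, the loop above folds separate checkpoint tensors into fused parameters via (param_name, shard_name, shard_id) entries. The actual stacked_params_mapping is not shown in this hunk, so the entries below are assumed from the common q/k/v fusion pattern and are illustrative only.

    # Assumed example mapping: fold q_proj/k_proj/v_proj checkpoints into one qkv_proj parameter.
    stacked_params_mapping = [
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("qkv_proj", "v_proj", "v"),
    ]

    name = "model.layers.0.self_attn.k_proj.weight"  # hypothetical checkpoint tensor name
    for param_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name not in name:
            continue
        fused_name = name.replace(shard_name, param_name)
        print(fused_name, shard_id)  # model.layers.0.self_attn.qkv_proj.weight k
        break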
@@ -355,8 +366,7 @@ class CohereForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
...
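A minimal sketch of the weight_loader dispatch in the hunk above: parameters that carry a custom weight_loader attribute (attached via set_weight_attrs) use it, everything else falls back to a plain copy. The default_weight_loader below is a simplified stand-in, not vLLM's actual implementation.

    import torch
    from torch.nn import Parameter

    def default_weight_loader(param, loaded_weight):
        # Simplified fallback: shapes must match, then copy in place.
        assert param.data.shape == loaded_weight.shape
        param.data.copy_(loaded_weight)

    param = Parameter(torch.zeros(4))   # no custom weight_loader attribute attached
    loaded_weight = torch.ones(4)
    weight_loader = getattr(param, "weight_loader", default_weight_loader)
    weight_loader(param, loaded_weight)
    print(param.data)  # tensor([1., 1., 1., 1.])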
@@ -171,7 +171,6 @@ class DbrxExperts(nn.Module):


 class DbrxAttention(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -251,7 +250,6 @@ class DbrxAttention(nn.Module):


 class DbrxFusedNormAttention(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -284,7 +282,6 @@ class DbrxFusedNormAttention(nn.Module):


 class DbrxBlock(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -312,7 +309,6 @@ class DbrxBlock(nn.Module):


 class DbrxModel(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
@@ -351,7 +347,6 @@ class DbrxModel(nn.Module):


 class DbrxForCausalLM(nn.Module):
-
     def __init__(
         self,
         config: DbrxConfig,
...
@@ -66,9 +66,9 @@ class BenchBatch:
             p_idx = prefix_req_idx[i // fork_num].item()
             n_idx = self.req_pool_indices[i].item()
             req_to_token[n_idx, :prefix_len] = req_to_token[p_idx, :prefix_len]
-            req_to_token[n_idx, prefix_len : prefix_len + extend_len] = (
-                self.out_cache_loc[i * extend_len : (i + 1) * extend_len]
-            )
+            req_to_token[
+                n_idx, prefix_len : prefix_len + extend_len
+            ] = self.out_cache_loc[i * extend_len : (i + 1) * extend_len]

     def update_decode(self, predict_ids, batch_size):
         assert predict_ids.shape[0] == batch_size
@@ -81,9 +81,9 @@ class BenchBatch:
             self.out_cache_cont_start,
             self.out_cache_cont_end,
         ) = self.token_to_kv_pool.alloc_contiguous(batch_size)
-        self.req_to_token_pool.req_to_token[self.req_pool_indices, self.seq_lens] = (
-            self.out_cache_loc
-        )
+        self.req_to_token_pool.req_to_token[
+            self.req_pool_indices, self.seq_lens
+        ] = self.out_cache_loc
         self.seq_lens.add_(1)
...
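A toy illustration of the advanced-indexing write reshaped in the last hunk: each running request records its newly allocated KV-cache slot at column seq_len of its row in req_to_token. All shapes and values here are made up.

    import torch

    req_to_token = torch.zeros(4, 8, dtype=torch.int64)  # (max_requests, max_context_len)
    req_pool_indices = torch.tensor([0, 2, 3])           # rows of the currently running requests
    seq_lens = torch.tensor([3, 1, 5])                   # current length of each request
    out_cache_loc = torch.tensor([100, 101, 102])        # newly allocated KV-cache slots

    # One write per request: row req_pool_indices[i], column seq_lens[i].
    req_to_token[req_pool_indices, seq_lens] = out_cache_loc
    print(req_to_token[2])  # tensor([  0, 101,   0,   0,   0,   0,   0,   0])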