Unverified Commit 564a898a authored by Liangsheng Yin, committed by GitHub

Optimize mem indices management (#619)

parent 5d264a90
......@@ -17,7 +17,8 @@ def run_one_batch_size(bs):
if args.input_len:
input_ids = [
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))] for _ in range(bs)
[int(x) for x in np.random.randint(0, high=16384, size=(args.input_len,))]
for _ in range(bs)
]
else:
text = [f"{i, }" for i in range(bs)]
......@@ -116,9 +117,11 @@ if __name__ == "__main__":
parser.add_argument("--port", type=int, default=None)
parser.add_argument("--backend", type=str, default="srt")
parser.add_argument("--input-len", type=int, default=None)
parser.add_argument("--batch-size", type=int, nargs='*', default=[1])
parser.add_argument("--batch-size", type=int, nargs="*", default=[1])
parser.add_argument("--max-tokens", type=int, default=256)
parser.add_argument("--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B")
parser.add_argument(
"--vllm-model-name", type=str, default="meta-llama/Meta-Llama-3-70B"
)
args = parser.parse_args()
if args.port is None:
......
......@@ -12,7 +12,6 @@ from sglang.utils import http_request
class RuntimeEndpoint(BaseBackend):
def __init__(
self,
base_url: str,
......@@ -38,7 +37,8 @@ class RuntimeEndpoint(BaseBackend):
self.model_info = res.json()
self.chat_template = get_chat_template_by_model_path(
self.model_info["model_path"])
self.model_info["model_path"]
)
def get_model_name(self):
return self.model_info["model_path"]
......@@ -124,7 +124,12 @@ class RuntimeEndpoint(BaseBackend):
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
for item in [
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
......@@ -171,7 +176,12 @@ class RuntimeEndpoint(BaseBackend):
else:
raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
for item in ["return_logprob", "logprob_start_len", "top_logprobs_num", "return_text_in_logprobs"]:
for item in [
"return_logprob",
"logprob_start_len",
"top_logprobs_num",
"return_text_in_logprobs",
]:
value = getattr(sampling_params, item, None)
if value is not None:
data[item] = value
......
......@@ -32,7 +32,6 @@ import logging
import multiprocessing
import time
import numpy as np
import torch
import torch.distributed as dist
......
......@@ -44,4 +44,5 @@ class GlobalConfig:
# adjust_cache: Adjust the position embedding of KV cache.
self.concate_and_append_mode = "no_adjust"
global_config = GlobalConfig()
......@@ -84,7 +84,7 @@ register_chat_template(
"system": ("SYSTEM:", "\n"),
"user": ("USER:", "\n"),
"assistant": ("ASSISTANT:", "\n"),
}
},
)
)
......@@ -177,7 +177,7 @@ register_chat_template(
"assistant": ("", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",)
stop_str=("<|im_end|>",),
)
)
......
......@@ -24,9 +24,9 @@ class SglSamplingParams:
presence_penalty: float = 0.0
ignore_eos: bool = False
return_logprob: Optional[bool] = None
logprob_start_len: Optional[int] = None,
top_logprobs_num: Optional[int] = None,
return_text_in_logprobs: Optional[bool] = None,
logprob_start_len: Optional[int] = (None,)
top_logprobs_num: Optional[int] = (None,)
return_text_in_logprobs: Optional[bool] = (None,)
# for constrained generation, not included in to_xxx_kwargs
dtype: Optional[str] = None
......
......@@ -8,7 +8,10 @@ from vllm.distributed.parallel_state import graph_capture
from sglang.global_config import global_config
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.infer_batch import (
Batch, ForwardMode, InputMetadata, init_flashinfer_args
Batch,
ForwardMode,
InputMetadata,
init_flashinfer_args,
)
......@@ -24,18 +27,28 @@ class CudaGraphRunner:
# Common inputs
self.max_bs = max_batch_size_to_capture
self.input_ids = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.req_pool_indices = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
self.seq_lens = torch.ones((self.max_bs,), dtype=torch.int32, device="cuda")
self.position_ids_offsets = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32, device="cuda")
self.position_ids_offsets = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
self.out_cache_loc = torch.zeros(
(self.max_bs,), dtype=torch.int32, device="cuda"
)
# FlashInfer inputs
self.flashinfer_workspace_buffer = self.model_runner.flashinfer_workspace_buffers[0]
self.flashinfer_workspace_buffer = (
self.model_runner.flashinfer_workspace_buffers[0]
)
self.flashinfer_kv_indptr = torch.zeros(
(self.max_bs + 1,), dtype=torch.int32, device="cuda"
)
self.flashinfer_kv_indices = torch.zeros(
(self.max_bs * model_runner.model_config.context_len,), dtype=torch.int32, device="cuda"
(self.max_bs * model_runner.model_config.context_len,),
dtype=torch.int32,
device="cuda",
)
self.flashinfer_kv_last_page_len = torch.ones(
(self.max_bs,), dtype=torch.int32, device="cuda"
......@@ -49,7 +62,12 @@ class CudaGraphRunner:
with graph_capture() as graph_capture_context:
self.stream = graph_capture_context.stream
for bs in batch_size_list:
graph, input_buffers, output_buffers, flashinfer_handler = self.capture_one_batch_size(bs)
(
graph,
input_buffers,
output_buffers,
flashinfer_handler,
) = self.capture_one_batch_size(bs)
self.graphs[bs] = graph
self.input_buffers[bs] = input_buffers
self.output_buffers[bs] = output_buffers
......@@ -71,17 +89,19 @@ class CudaGraphRunner:
# FlashInfer inputs
if not _grouped_size_compiled_for_decode_kernels(
self.model_runner.model_config.num_attention_heads // self.model_runner.tp_size,
self.model_runner.model_config.num_attention_heads
// self.model_runner.tp_size,
self.model_runner.model_config.get_num_kv_heads(self.model_runner.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffer, "NHD",
self.flashinfer_workspace_buffer,
"NHD",
use_cuda_graph=True,
use_tensor_cores=use_tensor_cores,
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[:bs+1],
paged_kv_indptr_buffer=self.flashinfer_kv_indptr[: bs + 1],
paged_kv_indices_buffer=self.flashinfer_kv_indices,
paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
)
......@@ -163,10 +183,14 @@ class CudaGraphRunner:
else:
output = LogitProcessorOutput(
next_token_logits=output.next_token_logits[:raw_bs],
next_token_logprobs=output.next_token_logprobs[:raw_bs] if output.next_token_logprobs is not None else None,
next_token_logprobs=output.next_token_logprobs[:raw_bs]
if output.next_token_logprobs is not None
else None,
normalized_prompt_logprobs=None,
prefill_token_logprobs=None,
prefill_top_logprobs=None,
decode_top_logprobs=output.decode_top_logprobs[:raw_bs] if output.decode_top_logprobs is not None else None,
decode_top_logprobs=output.decode_top_logprobs[:raw_bs]
if output.decode_top_logprobs is not None
else None,
)
return output
......@@ -668,7 +668,9 @@ class Batch:
sampled_index = torch.multinomial(probs_sort, num_samples=1)
except RuntimeError as e:
warnings.warn(f"Ignore errors in sampling: {e}")
sampled_index = torch.ones(probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device)
sampled_index = torch.ones(
probs_sort.shape[:-1] + (1,), dtype=torch.int64, device=probs.device
)
batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(
-1
)
......@@ -749,8 +751,14 @@ class InputMetadata:
skip_flashinfer_init=False,
):
if not skip_flashinfer_init and not model_runner.server_args.disable_flashinfer:
init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
model_runner.flashinfer_decode_wrapper)
init_flashinfer_args(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
model_runner.flashinfer_decode_wrapper,
)
batch_size = len(req_pool_indices)
......@@ -807,16 +815,24 @@ class InputMetadata:
)
if model_runner.server_args.disable_flashinfer:
(ret.triton_max_seq_len,
ret.triton_max_extend_len,
ret.triton_start_loc,
ret.triton_prefix_lens) = init_triton_args(forward_mode, seq_lens, prefix_lens)
(
ret.triton_max_seq_len,
ret.triton_max_extend_len,
ret.triton_start_loc,
ret.triton_prefix_lens,
) = init_triton_args(forward_mode, seq_lens, prefix_lens)
return ret
def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens, prefix_lens,
flashinfer_decode_wrapper):
def init_flashinfer_args(
forward_mode,
model_runner,
req_pool_indices,
seq_lens,
prefix_lens,
flashinfer_decode_wrapper,
):
num_qo_heads = model_runner.model_config.num_attention_heads // model_runner.tp_size
num_kv_heads = model_runner.model_config.get_num_kv_heads(model_runner.tp_size)
head_dim = model_runner.model_config.head_dim
......@@ -827,9 +843,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
else:
paged_kernel_lens = prefix_lens
kv_indptr = torch.zeros(
(batch_size + 1,), dtype=torch.int32, device="cuda"
)
kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
kv_indptr[1:] = torch.cumsum(paged_kernel_lens, dim=0)
req_pool_indices_cpu = req_pool_indices.cpu().numpy()
paged_kernel_lens_cpu = paged_kernel_lens.cpu().numpy()
......@@ -842,9 +856,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
],
dim=0,
).contiguous()
kv_last_page_len = torch.ones(
(batch_size,), dtype=torch.int32, device="cuda"
)
kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
if forward_mode == ForwardMode.DECODE:
flashinfer_decode_wrapper.end_forward()
......@@ -859,9 +871,7 @@ def init_flashinfer_args(forward_mode, model_runner, req_pool_indices, seq_lens,
)
else:
# extend part
qo_indptr = torch.zeros(
(batch_size + 1,), dtype=torch.int32, device="cuda"
)
qo_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32, device="cuda")
qo_indptr[1:] = torch.cumsum(seq_lens - prefix_lens, dim=0)
model_runner.flashinfer_prefill_wrapper_ragged.end_forward()
......
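The hunks above mostly reflow `init_flashinfer_args`, but the underlying pattern they format is worth a compact illustration: building CSR-style `kv_indptr`/`kv_indices` buffers for the paged KV cache from per-request KV lengths and a request-to-token table. The sketch below is a simplified, hypothetical helper (the shapes and the name `build_paged_kv_layout` are assumptions, not the actual sglang API):

```python
import torch

def build_paged_kv_layout(req_pool_indices, kv_lens, req_to_token):
    """Build CSR-style indptr/indices for a paged KV cache.

    req_pool_indices: (bs,) rows of req_to_token used by this batch
    kv_lens:          (bs,) number of cached tokens per request
    req_to_token:     (max_requests, max_context_len) token-slot table
    """
    batch_size = len(req_pool_indices)
    kv_indptr = torch.zeros((batch_size + 1,), dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(kv_lens, dim=0)  # prefix sums over KV lengths
    kv_indices = torch.cat(
        [
            req_to_token[req_pool_indices[i], : int(kv_lens[i])]
            for i in range(batch_size)
        ],
        dim=0,
    ).contiguous()
    # One token per page in this layout, so every last page is full (length 1).
    kv_last_page_len = torch.ones((batch_size,), dtype=torch.int32)
    return kv_indptr, kv_indices, kv_last_page_len
```

These three buffers are what the FlashInfer decode/prefill wrappers consume in the (elided) `begin_forward` calls above.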
......@@ -16,7 +16,12 @@ from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import ModelRegistry
from sglang.global_config import global_config
from sglang.srt.managers.controller.infer_batch import Batch, ForwardMode, InputMetadata, global_server_args_dict
from sglang.srt.managers.controller.infer_batch import (
Batch,
ForwardMode,
InputMetadata,
global_server_args_dict,
)
from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool
from sglang.srt.server_args import ServerArgs
from sglang.srt.utils import (
......@@ -83,7 +88,9 @@ class ModelRunner:
# Set some global args
global_server_args_dict["disable_flashinfer"] = server_args.disable_flashinfer
global_server_args_dict["attention_reduce_in_fp32"] = server_args.attention_reduce_in_fp32
global_server_args_dict[
"attention_reduce_in_fp32"
] = server_args.attention_reduce_in_fp32
# Load the model and create memory pool
self.load_model()
......@@ -217,7 +224,9 @@ class ModelRunner:
self.flashinfer_workspace_buffers[1], "NHD"
)
self.flashinfer_decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(
self.flashinfer_workspace_buffers[0], "NHD", use_tensor_cores=use_tensor_cores
self.flashinfer_workspace_buffers[0],
"NHD",
use_tensor_cores=use_tensor_cores,
)
def init_cuda_graphs(self):
......@@ -229,7 +238,9 @@ class ModelRunner:
logger.info(f"[gpu_id={self.gpu_id}] Capture cuda graph begin.")
batch_size_list = [1, 2, 4] + [i * 8 for i in range(1, 16)]
self.cuda_graph_runner = CudaGraphRunner(self, max_batch_size_to_capture=max(batch_size_list))
self.cuda_graph_runner = CudaGraphRunner(
self, max_batch_size_to_capture=max(batch_size_list)
)
self.cuda_graph_runner.capture(batch_size_list)
@torch.inference_mode()
......
......@@ -125,7 +125,8 @@ class RadixCache:
if x.lock_ref > 0:
continue
num_evicted += evict_callback(x.value)
evict_callback(x.value)
num_evicted += len(x.value)
self._delete_leaf(x)
if len(x.parent.children) == 0:
......
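Because `dec_refs` no longer returns a freed-slot count (see the memory-pool hunk below), the eviction loop now counts evicted tokens by the length of each leaf's value rather than by the callback's return value. A simplified sketch of that accounting, with a hypothetical signature and node attributes mirroring the loop above:

```python
def evict(leaves, num_tokens, evict_callback):
    """Evict least-recently-used, unlocked leaves until num_tokens are freed.

    Assumed node shape: .lock_ref (pin count), .value (KV-cache indices),
    .last_access_time. Not the library's actual signature.
    """
    num_evicted = 0
    for x in sorted(leaves, key=lambda n: n.last_access_time):
        if num_evicted >= num_tokens:
            break
        if x.lock_ref > 0:
            continue                 # pinned by an in-flight request
        evict_callback(x.value)      # frees the slots; return value unused now
        num_evicted += len(x.value)  # count the tokens ourselves
        # deleting the leaf and promoting its parent is omitted in this sketch
    return num_evicted
```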
......@@ -314,7 +314,9 @@ class ModelTpServer:
self.forward_queue.append(req)
def get_new_fill_batch(self) -> Optional[Batch]:
running_bs = len(self.running_batch.reqs) if self.running_batch is not None else 0
running_bs = (
len(self.running_batch.reqs) if self.running_batch is not None else 0
)
if running_bs >= self.max_running_requests:
return
......
......@@ -39,10 +39,12 @@ class ReqToTokenPool:
class TokenToKVPool:
def __init__(self, size, dtype, head_num, head_dim, layer_num):
self.size = size
# mem_state is the reference counter.
# This can be promised:
# assert torch.all(mem_state <= 1) and torch.all(mem_state >= 0)
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.int16, device="cuda")
self.total_ref_ct = 0
self.mem_state = torch.zeros((self.size + 1,), dtype=torch.bool, device="cuda")
self.total_size = self.size
self.total_alloc = 0
# [size, key/value, head_num, head_dim] for each layer
self.kv_data = [
......@@ -71,7 +73,9 @@ class TokenToKVPool:
addition_size = need_size - buffer_len
alloc_size = max(addition_size, self.prefetch_chunk_size)
select_index = torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
select_index = (
torch.nonzero(self.mem_state == 0).squeeze(1)[:alloc_size].to(torch.int32)
)
if select_index.shape[0] < addition_size:
return None
......@@ -105,26 +109,22 @@ class TokenToKVPool:
return select_index.to(torch.int32), start_loc, start_loc + need_size
def used_size(self):
return len(torch.nonzero(self.mem_state).squeeze(1))
return self.total_alloc
def available_size(self):
return torch.sum(self.mem_state == 0).item() + len(self.prefetch_buffer)
return self.total_size - self.total_alloc + len(self.prefetch_buffer)
def add_refs(self, token_index: torch.Tensor):
self.total_ref_ct += len(token_index)
self.mem_state[token_index] += 1
self.total_alloc += len(token_index)
self.mem_state[token_index] ^= True
def dec_refs(self, token_index: torch.Tensor):
self.total_ref_ct -= len(token_index)
self.mem_state[token_index] -= 1
num_freed = torch.sum(self.mem_state[token_index] == 0)
return num_freed
self.total_alloc -= len(token_index)
self.mem_state[token_index] ^= True
def clear(self):
self.mem_state.fill_(0)
self.total_ref_ct = 0
self.total_alloc = 0
# We also add one slot. This slot is used for writing dummy output from padded tokens.
self.add_refs(torch.tensor([0], dtype=torch.int32))
self.mem_state[0] = True
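The substantive change in this hunk replaces the int16 per-slot reference counter with a boolean occupancy bitmap plus a running `total_alloc` counter, so `used_size`/`available_size` become O(1) bookkeeping instead of scanning `mem_state` with `torch.nonzero`/`torch.sum`. A minimal standalone sketch of that scheme (the class name `BoolTokenPool` and CPU default are assumptions for illustration):

```python
import torch

class BoolTokenPool:
    """Token-slot pool tracked by a bool bitmap and an allocation counter.

    Assumes single ownership per slot (no ref counts above 1), the same
    invariant the old int16 counter already asserted.
    """

    def __init__(self, size: int, device: str = "cpu"):
        self.size = size
        # One extra slot (index 0) for dummy writes from padded tokens.
        self.mem_state = torch.zeros((size + 1,), dtype=torch.bool, device=device)
        self.total_alloc = 0
        self.clear()

    def alloc(self, need_size: int):
        # Free slots are the ones still False in the bitmap.
        free_index = torch.nonzero(~self.mem_state).squeeze(1)[:need_size]
        if free_index.shape[0] < need_size:
            return None  # out of memory
        self.add_refs(free_index)
        return free_index.to(torch.int32)

    def add_refs(self, token_index: torch.Tensor):
        self.total_alloc += len(token_index)
        self.mem_state[token_index] ^= True  # False -> True (occupied)

    def dec_refs(self, token_index: torch.Tensor):
        self.total_alloc -= len(token_index)
        self.mem_state[token_index] ^= True  # True -> False (free)

    def used_size(self) -> int:
        return self.total_alloc              # O(1), no bitmap scan

    def available_size(self) -> int:
        return self.size - self.total_alloc

    def clear(self):
        self.mem_state.fill_(False)
        self.total_alloc = 0
        self.mem_state[0] = True             # keep the dummy slot occupied
```

The XOR toggle matches the diff above and relies on the 0/1-only invariant; freeing an unallocated slot would silently re-mark it as occupied, which is why the ownership assumption matters.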
......@@ -5,12 +5,9 @@ from typing import Any, Dict, Iterable, Optional, Tuple
import torch
from torch import nn
from vllm.config import CacheConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
......@@ -31,7 +28,6 @@ from sglang.srt.managers.controller.model_runner import InputMetadata
class MiniCPMMLP(nn.Module):
def __init__(
self,
hidden_size: int,
......@@ -67,7 +63,6 @@ class MiniCPMMLP(nn.Module):
class MiniCPMAttention(nn.Module):
def __init__(
self,
hidden_size: int,
......@@ -152,7 +147,6 @@ class MiniCPMAttention(nn.Module):
class MiniCPMDecoderLayer(nn.Module):
def __init__(
self,
config,
......@@ -217,7 +211,6 @@ class MiniCPMDecoderLayer(nn.Module):
class MiniCPMModel(nn.Module):
def __init__(
self,
config,
......@@ -274,7 +267,7 @@ class MiniCPMForCausalLM(nn.Module):
) -> None:
super().__init__()
self.config = config
self.num_experts = getattr(self.config, "num_experts", 0)
self.quant_config = quant_config
self.model = MiniCPMModel(config, quant_config=quant_config)
......
......@@ -8,24 +8,28 @@ import torch
import torch.nn.functional as F
from torch import nn
from transformers import PretrainedConfig
from vllm.config import CacheConfig
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.distributed import (
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead, VocabParallelEmbedding)
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors, SamplerOutput
......@@ -34,8 +38,8 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
class Qwen2MoeMLP(nn.Module):
class Qwen2MoeMLP(nn.Module):
def __init__(
self,
hidden_size: int,
......@@ -46,17 +50,20 @@ class Qwen2MoeMLP(nn.Module):
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
)
self.down_proj = RowParallelLinear(
intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config,
reduce_results=reduce_results)
quant_config=quant_config,
reduce_results=reduce_results,
)
if hidden_act != "silu":
raise ValueError(f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now.")
raise ValueError(
f"Unsupported activation: {hidden_act}. "
"Only silu is supported for now."
)
self.act_fn = SiluAndMul()
def forward(self, x):
......@@ -67,7 +74,6 @@ class Qwen2MoeMLP(nn.Module):
class Qwen2MoeSparseMoeBlock(nn.Module):
def __init__(
self,
config: PretrainedConfig,
......@@ -79,20 +85,22 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
if self.tp_size > config.num_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_experts}.")
self.experts = FusedMoE(num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config)
self.gate = ReplicatedLinear(config.hidden_size,
config.num_experts,
bias=False,
quant_config=None)
f"the number of experts {config.num_experts}."
)
self.experts = FusedMoE(
num_experts=config.num_experts,
top_k=config.num_experts_per_tok,
hidden_size=config.hidden_size,
intermediate_size=config.moe_intermediate_size,
reduce_results=False,
renormalize=config.norm_topk_prob,
quant_config=quant_config,
)
self.gate = ReplicatedLinear(
config.hidden_size, config.num_experts, bias=False, quant_config=None
)
if config.shared_expert_intermediate_size > 0:
self.shared_expert = Qwen2MoeMLP(
hidden_size=config.hidden_size,
......@@ -103,9 +111,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
)
else:
self.shared_expert = None
self.shared_expert_gate = torch.nn.Linear(config.hidden_size,
1,
bias=False)
self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
......@@ -114,24 +120,24 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
if self.shared_expert is not None:
shared_output = self.shared_expert(hidden_states)
if self.shared_expert_gate is not None:
shared_output = F.sigmoid(
self.shared_expert_gate(hidden_states)) * shared_output
shared_output = (
F.sigmoid(self.shared_expert_gate(hidden_states)) * shared_output
)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
final_hidden_states = self.experts(
hidden_states=hidden_states, router_logits=router_logits
)
if shared_output is not None:
final_hidden_states = final_hidden_states + shared_output
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
final_hidden_states)
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
return final_hidden_states.view(num_tokens, hidden_dim)
class Qwen2MoeAttention(nn.Module):
def __init__(
self,
hidden_size: int,
......@@ -190,17 +196,19 @@ class Qwen2MoeAttention(nn.Module):
base=rope_theta,
rope_scaling=rope_scaling,
)
self.attn = RadixAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata
input_metadata: InputMetadata,
) -> torch.Tensor:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
......@@ -211,7 +219,6 @@ class Qwen2MoeAttention(nn.Module):
class Qwen2MoeDecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
......@@ -223,8 +230,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings",
8192)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
self.self_attn = Qwen2MoeAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
......@@ -239,13 +245,13 @@ class Qwen2MoeDecoderLayer(nn.Module):
# Note: Qwen/Qwen2-57B-A14B-Instruct does not have
# `mlp_only_layers` in the config.
mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else
config.mlp_only_layers)
mlp_only_layers = (
[] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers
)
if (layer_id not in mlp_only_layers) and (
config.num_experts > 0 and
(layer_id + 1) % config.decoder_sparse_step == 0):
self.mlp = Qwen2MoeSparseMoeBlock(config=config,
quant_config=quant_config)
config.num_experts > 0 and (layer_id + 1) % config.decoder_sparse_step == 0
):
self.mlp = Qwen2MoeSparseMoeBlock(config=config, quant_config=quant_config)
else:
self.mlp = Qwen2MoeMLP(
hidden_size=config.hidden_size,
......@@ -253,10 +259,10 @@ class Qwen2MoeDecoderLayer(nn.Module):
hidden_act=config.hidden_act,
quant_config=quant_config,
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
......@@ -270,23 +276,20 @@ class Qwen2MoeDecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
input_metadata=input_metadata
input_metadata=input_metadata,
)
# Fully Connected
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual)
hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
hidden_states = self.mlp(hidden_states)
return hidden_states, residual
class Qwen2MoeModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
......@@ -301,13 +304,14 @@ class Qwen2MoeModel(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
Qwen2MoeDecoderLayer(config,
layer_id,
cache_config,
quant_config=quant_config)
for layer_id in range(config.num_hidden_layers)
])
self.layers = nn.ModuleList(
[
Qwen2MoeDecoderLayer(
config, layer_id, cache_config, quant_config=quant_config
)
for layer_id in range(config.num_hidden_layers)
]
)
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
def forward(
......@@ -315,7 +319,7 @@ class Qwen2MoeModel(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
......@@ -324,10 +328,9 @@ class Qwen2MoeModel(nn.Module):
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
hidden_states, residual = layer(positions,
hidden_states,
input_metadata,
residual)
hidden_states, residual = layer(
positions, hidden_states, input_metadata, residual
)
hidden_states, _ = self.norm(hidden_states, residual)
return hidden_states
......@@ -346,9 +349,9 @@ class Qwen2MoeForCausalLM(nn.Module):
self.config = config
self.quant_config = quant_config
self.model = Qwen2MoeModel(config, cache_config, quant_config)
self.lm_head = ParallelLMHead(config.vocab_size,
config.hidden_size,
quant_config=quant_config)
self.lm_head = ParallelLMHead(
config.vocab_size, config.hidden_size, quant_config=quant_config
)
self.logits_processor = LogitsProcessor(config)
self.sampler = Sampler()
......@@ -357,17 +360,22 @@ class Qwen2MoeForCausalLM(nn.Module):
input_ids: torch.Tensor,
positions: torch.Tensor,
input_metadata: InputMetadata,
input_embeds: torch.Tensor = None
input_embeds: torch.Tensor = None,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
return self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
def compute_logits(
self,
input_ids: torch.Tensor,
hidden_states: torch.Tensor,
input_metadata: InputMetadata,
) -> torch.Tensor:
hidden_states = self.model(input_ids, positions, input_metadata,
input_embeds)
return self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
input_metadata)
def compute_logits(self, input_ids: torch.Tensor, hidden_states: torch.Tensor,
input_metadata: InputMetadata) -> torch.Tensor:
logits = self.logits_processor(input_ids, hidden_states, self.lm_head.weight,
input_metadata)
logits = self.logits_processor(
input_ids, hidden_states, self.lm_head.weight, input_metadata
)
return logits
def sample(
......@@ -391,11 +399,18 @@ class Qwen2MoeForCausalLM(nn.Module):
expert_params_mapping = [
# These are the weights for the experts
# (param_name, weight_name, expert_id, shard_id)
("experts.w13_weight" if weight_name in ["gate_proj", "up_proj"]
else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight", expert_id, shard_id)
for expert_id in range(self.config.num_experts) for shard_id,
weight_name in enumerate(["gate_proj", "down_proj", "up_proj"])
(
"experts.w13_weight"
if weight_name in ["gate_proj", "up_proj"]
else "experts.w2_weight",
f"experts.{expert_id}.{weight_name}.weight",
expert_id,
shard_id,
)
for expert_id in range(self.config.num_experts)
for shard_id, weight_name in enumerate(
["gate_proj", "down_proj", "up_proj"]
)
]
params_dict = dict(self.named_parameters())
......@@ -433,11 +448,13 @@ class Qwen2MoeForCausalLM(nn.Module):
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id)
weight_loader(
param,
loaded_weight,
weight_name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
......@@ -447,8 +464,10 @@ class Qwen2MoeForCausalLM(nn.Module):
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
EntryClass = Qwen2MoeForCausalLM
......@@ -474,9 +474,9 @@ def monkey_patch_vllm_dummy_weight_loader():
DummyModelLoader,
LoRAConfig,
ModelConfig,
MultiModalConfig,
ParallelConfig,
SchedulerConfig,
MultiModalConfig,
_initialize_model,
initialize_dummy_weights,
nn,
......