Unverified Commit 53cef815 authored by Lianmin Zheng, committed by GitHub

Improve weight loading and code style (#3174)

parent 351a72d4
......@@ -329,12 +329,14 @@ class ColumnParallelLinear(LinearBase):
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
use_presharded_weights: bool = False,
):
super().__init__(
input_size, output_size, skip_bias_add, params_dtype, quant_config, prefix
)
self.gather_output = gather_output
self.use_presharded_weights = use_presharded_weights
# Divide the weight matrix along the last dimension.
if tp_rank is None:
......@@ -402,7 +404,8 @@ class ColumnParallelLinear(LinearBase):
if output_dim is not None and not use_bitsandbytes_4bit:
shard_size = param_data.shape[output_dim]
start_idx = self.tp_rank * shard_size
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
if not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for loading scales off disk, which often do not
# have a shape (such as in the case of AutoFP8).
......@@ -418,7 +421,11 @@ class ColumnParallelLinear(LinearBase):
if len(loaded_weight.shape) == 0:
assert loaded_weight.numel() == 1
loaded_weight = loaded_weight.reshape(1)
param.load_column_parallel_weight(loaded_weight, tp_rank=self.tp_rank)
param.load_column_parallel_weight(
loaded_weight,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
def forward(self, input_):
bias = self.bias if not self.skip_bias_add else None
......@@ -499,7 +506,9 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
prefix=prefix,
tp_rank=tp_rank,
tp_size=tp_size,
use_presharded_weights=use_presharded_weights,
)
self.prefix = prefix
def weight_loader(
self,
......@@ -743,6 +752,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix: str = "",
tp_rank: Optional[int] = None,
tp_size: Optional[int] = None,
load_presharded_attn: bool = False,
):
self.hidden_size = hidden_size
self.head_size = head_size
......@@ -772,6 +782,7 @@ class QKVParallelLinear(ColumnParallelLinear):
self.num_kv_heads * self.head_size * tp_size, # k_proj
self.num_kv_heads * self.head_size * tp_size, # v_proj
]
self.use_presharded_weights = load_presharded_attn
super().__init__(
input_size=input_size,
......@@ -784,6 +795,7 @@ class QKVParallelLinear(ColumnParallelLinear):
prefix=prefix,
tp_rank=tp_rank,
tp_size=tp_size,
use_presharded_weights=self.use_presharded_weights,
)
def _get_shard_offset_mapping(self, loaded_shard_id: str):
......@@ -842,9 +854,10 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size=shard_size, shard_offset=shard_offset
)
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
if not self.use_presharded_weights:
loaded_weight_shard = loaded_weight.narrow(
param.output_dim, shard_offset, shard_size
)
self.weight_loader_v2(param, loaded_weight_shard, shard_id)
def weight_loader_v2(
......@@ -882,6 +895,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset=shard_offset,
shard_size=shard_size,
tp_rank=self.tp_rank,
use_presharded_weights=self.use_presharded_weights,
)
def weight_loader(
......@@ -987,9 +1001,10 @@ class QKVParallelLinear(ColumnParallelLinear):
param, orig_qkv_offsets, shard_id
)
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
if not self.use_presharded_weights:
loaded_weight_shard = loaded_weight.narrow(
output_dim, shard_offset, shard_size
)
self.weight_loader(param, loaded_weight_shard, shard_id)
return
......@@ -1049,7 +1064,7 @@ class QKVParallelLinear(ColumnParallelLinear):
# bitsandbytes loads only this shard's portion of the weights,
# so there is no need to narrow here
if not use_bitsandbytes_4bit:
if not use_bitsandbytes_4bit and not self.use_presharded_weights:
loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
# Special case for AQLM codebooks.
......
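Taken together, the ColumnParallelLinear changes apply one rule: when the checkpoint already stores one shard per tensor-parallel rank, skip the narrow step and copy the tensor as-is. A minimal standalone sketch of that rule (tensor names and sizes here are illustrative, not sglang's actual parameters):

import torch

def load_column_shard(param: torch.Tensor, loaded: torch.Tensor,
                      tp_rank: int, output_dim: int = 0,
                      use_presharded_weights: bool = False) -> None:
    # Each rank owns a slice of size `shard_size` along the output dimension.
    shard_size = param.shape[output_dim]
    if not use_presharded_weights:
        # Full checkpoint: cut out this rank's slice before copying.
        loaded = loaded.narrow(output_dim, tp_rank * shard_size, shard_size)
    # Pre-sharded checkpoint: the file already contains exactly this rank's slice.
    assert param.shape == loaded.shape, f"{param.shape=}, {loaded.shape=}"
    param.copy_(loaded)

# Toy usage: tp_size=2, rank 1, full weight of shape (8, 4) split into (4, 4) shards.
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
shard = torch.empty(4, 4)
load_column_shard(shard, full, tp_rank=1)                                    # narrows the full tensor
load_column_shard(shard, full[4:], tp_rank=1, use_presharded_weights=True)   # copies as-is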
......@@ -114,6 +114,7 @@ class EPMoE(torch.nn.Module):
tp_size: Optional[int] = None,
prefix: str = "",
correction_bias: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
activation: str = "silu",
):
super().__init__()
......@@ -141,6 +142,7 @@ class EPMoE(torch.nn.Module):
self.num_expert_group = num_expert_group
self.topk_group = topk_group
self.correction_bias = correction_bias
self.custom_routing_function = custom_routing_function
self.activation = activation
if quant_config is None:
......@@ -184,6 +186,7 @@ class EPMoE(torch.nn.Module):
topk_group=self.topk_group,
num_expert_group=self.num_expert_group,
correction_bias=self.correction_bias,
custom_routing_function=self.custom_routing_function,
)
reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
......@@ -257,16 +260,20 @@ class EPMoE(torch.nn.Module):
dtype=torch.float32,
device=hidden_states.device,
)
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
if self.activation == "silu":
silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
gateup_output,
down_input,
gateup_output.shape[1],
reorder_topk_ids,
self.w2_input_scale,
self.start_expert_id,
self.end_expert_id,
BLOCK_SIZE=512,
)
else:
raise ValueError(f"Unsupported activation: {self.activation=}")
# GroupGemm-1
down_output = torch.empty(
......@@ -312,7 +319,6 @@ class EPMoE(torch.nn.Module):
ckpt_up_proj_name: str,
num_experts: int,
) -> List[Tuple[str, str, int, str]]:
return [
# (param_name, weight_name, expert_id, shard_id)
(
......@@ -357,7 +363,6 @@ class EPMoE(torch.nn.Module):
)
return
expert_data = param.data[expert_id]
if shard_id == "w2":
param.data[expert_id] = loaded_weight
elif shard_id == "w1":
......
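The activation guard added to EPMoE follows a dispatch-and-fail-loudly pattern: run the fused SiLU kernel for "silu", reject anything else instead of silently producing wrong results. A hedged sketch of the same idea in plain PyTorch (silu_and_mul below is a reference version of the fused Triton kernel, not the kernel used in the diff, and it omits the input-scale handling):

import torch

def silu_and_mul(gateup_output: torch.Tensor) -> torch.Tensor:
    # Reference behavior of the fused kernel: split the gate/up halves,
    # apply SiLU to the gate, and multiply elementwise.
    gate, up = gateup_output.chunk(2, dim=-1)
    return torch.nn.functional.silu(gate) * up

def apply_activation(gateup_output: torch.Tensor, activation: str = "silu") -> torch.Tensor:
    if activation == "silu":
        return silu_and_mul(gateup_output)
    # Any other activation is rejected explicitly.
    raise ValueError(f"Unsupported activation: {activation=}")

x = torch.randn(4, 16)
y = apply_activation(x)          # works
# apply_activation(x, "gelu")    # raises ValueError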
......@@ -124,7 +124,13 @@ class _ColumnvLLMParameter(BasevLLMParameter):
assert param_data.shape == loaded_weight.shape
param_data.copy_(loaded_weight)
def load_qkv_weight(self, loaded_weight: torch.Tensor, tp_rank: int, **kwargs):
def load_qkv_weight(
self,
loaded_weight: torch.Tensor,
tp_rank: int,
use_presharded_weights: bool = False,
**kwargs,
):
shard_offset = kwargs.get("shard_offset")
shard_size = kwargs.get("shard_size")
......@@ -142,11 +148,14 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data = self.data
shard_id = tp_rank if shard_id == "q" else tp_rank // num_heads
param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
if not use_presharded_weights:
loaded_weight = loaded_weight.narrow(
self.output_dim, shard_id * shard_size, shard_size
)
assert param_data.shape == loaded_weight.shape
assert (
param_data.shape == loaded_weight.shape
), f"{param_data.shape=}, {loaded_weight.shape=}"
param_data.copy_(loaded_weight)
......@@ -292,7 +301,7 @@ class PackedColumnParameter(_ColumnvLLMParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs
**kwargs,
):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
......@@ -336,7 +345,7 @@ class PackedvLLMParameter(ModelWeightParameter):
packed_factor: Union[int, Fraction],
packed_dim: int,
marlin_tile_size: Optional[int] = None,
**kwargs
**kwargs,
):
self._packed_factor = packed_factor
self._packed_dim = packed_dim
......
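For the fused QKV path the pre-sharded rule applies to the source tensor only: the destination region inside the fused parameter is always narrowed to the q/k/v slot, while the checkpoint tensor is narrowed only when it is not already sharded. A hedged sketch of that control flow (names are illustrative, not the vLLM parameter API):

import torch

def load_qkv_shard(param: torch.Tensor, loaded: torch.Tensor, *,
                   output_dim: int, shard_offset: int, shard_size: int,
                   shard_rank: int, use_presharded_weights: bool = False) -> None:
    # Destination: the q/k/v region inside the fused parameter for this rank.
    dst = param.narrow(output_dim, shard_offset, shard_size)
    if not use_presharded_weights:
        # Source: pick this rank's slice out of the full checkpoint tensor.
        loaded = loaded.narrow(output_dim, shard_rank * shard_size, shard_size)
    assert dst.shape == loaded.shape, f"{dst.shape=}, {loaded.shape=}"
    dst.copy_(loaded)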
......@@ -247,6 +247,7 @@ class Req:
# Each decode stage's output ids
self.output_ids = []
# fill_ids = origin_input_ids + output_ids. Updated if chunked.
self.fill_ids = None
self.session_id = session_id
self.input_embeds = input_embeds
......
......@@ -486,7 +486,7 @@ class Scheduler:
@torch.no_grad()
def event_loop_overlap(self):
"""A scheduler loop that overlaps the CPU processing and GPU computation."""
result_queue = deque()
self.result_queue = deque()
while True:
recv_reqs = self.recv_requests()
......@@ -497,7 +497,7 @@ class Scheduler:
if batch:
result = self.run_batch(batch)
result_queue.append((batch.copy(), result))
self.result_queue.append((batch.copy(), result))
if self.last_batch is None:
# Create a dummy first batch to start the pipeline for overlap schedule.
......@@ -511,7 +511,7 @@ class Scheduler:
if self.last_batch:
# Process the results of the last batch
tmp_batch, tmp_result = result_queue.popleft()
tmp_batch, tmp_result = self.result_queue.popleft()
tmp_batch.next_batch_sampling_info = (
self.tp_worker.cur_sampling_info if batch else None
)
......@@ -642,7 +642,7 @@ class Scheduler:
self.waiting_queue.append(req)
return
# Handle image inputs
# Handle multimodal inputs
if recv_req.image_inputs is not None:
image_inputs = ImageInputs.from_dict(recv_req.image_inputs)
# Expand a single image token into multiple dummy tokens for receiving image embeddings
......@@ -743,7 +743,13 @@ class Scheduler:
req.logprob_start_len = len(req.origin_input_ids) - 1
self.waiting_queue.append(req)
def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
def log_prefill_stats(
self,
adder: PrefillAdder,
can_run_list: List[Req],
running_bs: ScheduleBatch,
has_being_chunked: bool,
):
self.tree_cache_metrics["total"] += (
adder.log_input_tokens + adder.log_hit_tokens
) / 10**9
......
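Promoting result_queue to self.result_queue lets other scheduler methods inspect in-flight batches; the overlap loop itself is a one-batch pipeline in which batch N is launched before batch N-1 is post-processed. A minimal sketch of that shape, with placeholder methods standing in for the real scheduler:

from collections import deque

class OverlapLoopSketch:
    """Toy version of the overlap loop: launch batch N, then finish batch N-1."""

    def __init__(self):
        self.result_queue = deque()
        self.last_batch = None

    def step(self, batch):
        if batch is not None:
            # Launch the current batch; its result is consumed one iteration later,
            # so CPU-side bookkeeping overlaps with the GPU computation.
            result = self.run_batch(batch)
            self.result_queue.append((batch, result))
        if self.last_batch is not None:
            prev_batch, prev_result = self.result_queue.popleft()
            self.process_batch_result(prev_batch, prev_result)
        self.last_batch = batch

    def run_batch(self, batch):                       # placeholder for GPU work
        return f"result({batch})"

    def process_batch_result(self, batch, result):    # placeholder for CPU post-processing
        print(batch, result)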
......@@ -218,7 +218,7 @@ class ModelRunner:
def init_torch_distributed(self):
logger.info("Init torch distributed begin.")
# Init torch distributed
torch.get_device_module(self.device).set_device(self.gpu_id)
if self.device == "cuda":
backend = "nccl"
......
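Calling set_device before any distributed initialization pins subsequent CUDA allocations (including NCCL buffers) to the intended GPU. A hedged sketch of the ordering, assuming a setup that uses environment-variable rendezvous:

import torch
import torch.distributed as dist

def init_torch_distributed_sketch(device: str, gpu_id: int) -> None:
    # Bind this process to its GPU first so later allocations land on the right device.
    torch.get_device_module(device).set_device(gpu_id)
    backend = "nccl" if device == "cuda" else "gloo"
    if not dist.is_initialized():
        # Assumes MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE are set in the environment.
        dist.init_process_group(backend=backend)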
......@@ -404,8 +404,13 @@ def np_cache_weights_iterator(
def safetensors_weights_iterator(
hf_weights_files: List[str],
is_all_weights_sharded: bool = False,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
"""Iterate over the weights in the model safetensor files.
If is_all_weights_sharded is True, it uses more optimize read by reading an
entire file instead of reading each tensor one by one.
"""
enable_tqdm = (
not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
)
......@@ -415,9 +420,14 @@ def safetensors_weights_iterator(
disable=not enable_tqdm,
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
if not is_all_weights_sharded:
with safe_open(st_file, framework="pt") as f:
for name in f.keys(): # noqa: SIM118
param = f.get_tensor(name)
yield name, param
else:
result = load_file(st_file, device="cpu")
for name, param in result.items():
yield name, param
......
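A hedged sketch of the two read paths used above: safe_open streams tensors lazily (one disk read per tensor), while load_file pulls the whole file in a single read, which pays off when every tensor in the file belongs to this rank anyway, as with pre-sharded checkpoints.

import torch
from safetensors import safe_open
from safetensors.torch import load_file

def iterate_safetensors(st_file: str, read_whole_file: bool = False):
    if read_whole_file:
        # One bulk read: best when this process needs every tensor in the file.
        for name, param in load_file(st_file, device="cpu").items():
            yield name, param
    else:
        # Lazy per-tensor reads: avoids materializing tensors we may skip.
        with safe_open(st_file, framework="pt") as f:
            for name in f.keys():
                yield name, f.get_tensor(name)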
......@@ -75,6 +75,7 @@ class ServerArgs:
# Other runtime options
tp_size: int = 1
stream_interval: int = 1
stream_output: bool = False
random_seed: Optional[int] = None
constrained_json_whitespace_pattern: Optional[str] = None
watchdog_timeout: float = 300
......@@ -500,6 +501,11 @@ class ServerArgs:
default=ServerArgs.stream_interval,
help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
)
parser.add_argument(
"--stream-output",
action="store_true",
help="Whether to output as a sequence of disjoint segments.",
)
parser.add_argument(
"--random-seed",
type=int,
......
......@@ -774,7 +774,7 @@ def get_zmq_socket(
def dump_to_file(dirpath, name, value):
from vllm.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed import get_tensor_model_parallel_rank
if get_tensor_model_parallel_rank() != 0:
return
......
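The import swap only moves the rank lookup from vllm to sglang's own distributed module; the rank-0-only dump pattern it guards looks roughly like this (a sketch with an explicit tp_rank argument, not the actual utility):

import os
import torch

def dump_to_file_sketch(dirpath: str, name: str, value: torch.Tensor, tp_rank: int) -> None:
    # Only one tensor-parallel rank writes, so ranks do not clobber each other's files.
    if tp_rank != 0:
        return
    os.makedirs(dirpath, exist_ok=True)
    torch.save(value.detach().cpu(), os.path.join(dirpath, f"{name}.pt"))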
......@@ -34,7 +34,7 @@ DEFAULT_SMALL_MOE_MODEL_NAME_FOR_TEST = "Qwen/Qwen1.5-MoE-A2.7B"
DEFAULT_SMALL_EMBEDDING_MODEL_NAME_FOR_TEST = "Alibaba-NLP/gte-Qwen2-1.5B-instruct"
DEFAULT_MLA_MODEL_NAME_FOR_TEST = "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct"
DEFAULT_MLA_FP8_MODEL_NAME_FOR_TEST = "neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8"
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 600
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 1000
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP1 = "meta-llama/Llama-3.1-8B-Instruct,mistralai/Mistral-7B-Instruct-v0.3,deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct,google/gemma-2-27b-it"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_TP2 = "meta-llama/Llama-3.1-70B-Instruct,mistralai/Mixtral-8x7B-Instruct-v0.1,Qwen/Qwen2-57B-A14B-Instruct"
DEFAULT_MODEL_NAME_FOR_NIGHTLY_EVAL_FP8_TP1 = "neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8,neuralmagic/Mistral-7B-Instruct-v0.3-FP8,neuralmagic/DeepSeek-Coder-V2-Lite-Instruct-FP8,neuralmagic/gemma-2-2b-it-FP8"
......@@ -135,10 +135,6 @@ def call_generate_srt_raw(prompt, temperature, max_tokens, stop=None, url=None):
return pred
def call_generate_gserver(prompt, temperature, max_tokens, stop=None, url=None):
raise NotImplementedError()
def call_generate_guidance(
prompt, temperature, max_tokens, stop=None, n=1, regex=None, model=None
):
......@@ -530,6 +526,48 @@ def get_similarities(vec1, vec2):
return F.cosine_similarity(torch.tensor(vec1), torch.tensor(vec2), dim=0)
def get_benchmark_args(
base_url="",
dataset_name="",
dataset_path="",
tokenizer="",
num_prompts=500,
random_input_len=4096,
random_output_len=2048,
request_rate=float("inf"),
disable_stream=False,
disable_ignore_eos=False,
):
return SimpleNamespace(
backend="sglang",
base_url=base_url,
host=None,
port=None,
dataset_name=dataset_name,
dataset_path=dataset_path,
model=None,
tokenizer=tokenizer,
num_prompts=num_prompts,
sharegpt_output_len=None,
sharegpt_context_len=None,
random_input_len=random_input_len,
random_output_len=random_output_len,
random_range_ratio=0.0,
request_rate=request_rate,
multi=None,
output_file=None,
disable_tqdm=False,
disable_stream=disable_stream,
return_logprob=False,
seed=0,
disable_ignore_eos=disable_ignore_eos,
extra_request_body=None,
apply_chat_template=False,
profile=None,
lora_name=None,
)
def run_bench_serving(
model,
num_prompts,
......@@ -554,33 +592,17 @@ def run_bench_serving(
)
# Run benchmark
args = SimpleNamespace(
backend="sglang",
args = get_benchmark_args(
base_url=base_url,
host=None,
port=None,
dataset_name=dataset_name,
dataset_path=dataset_path,
model=None,
tokenizer=tokenizer,
num_prompts=num_prompts,
sharegpt_output_len=None,
sharegpt_context_len=None,
random_input_len=random_input_len,
random_output_len=random_output_len,
random_range_ratio=0.0,
request_rate=request_rate,
multi=None,
output_file=None,
disable_tqdm=False,
disable_stream=disable_stream,
return_logprob=False,
seed=0,
disable_ignore_eos=disable_ignore_eos,
extra_request_body=None,
apply_chat_template=False,
profile=None,
lora_name=None,
)
try:
......@@ -596,6 +618,38 @@ def run_bench_serving(
return res
def run_bench_serving_multi(
model,
base_url,
other_server_args,
benchmark_args,
need_warmup=False,
):
# Launch the server
process = popen_launch_server(
model,
base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=other_server_args,
)
# Run the benchmark for every set of arguments
res_l = []
try:
for args in benchmark_args:
if need_warmup:
warmup_args = copy.deepcopy(args)
warmup_args.num_prompts = 16
run_benchmark(warmup_args)
res = run_benchmark(args)
res_l.append((args, res))
finally:
kill_process_tree(process.pid)
return res_l
def run_bench_one_batch(model, other_args):
command = [
"python3",
......
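A hedged usage sketch of the new helpers: build a few argument sets with get_benchmark_args and sweep them against a single launched server with run_bench_serving_multi. The model name, URL, and import path are placeholders/assumptions, not values from this commit.

from sglang.test.test_utils import get_benchmark_args, run_bench_serving_multi  # assumed import path

base_url = "http://127.0.0.1:30000"
benchmark_args = [
    get_benchmark_args(
        base_url=base_url,
        dataset_name="random",
        num_prompts=200,
        request_rate=rate,
    )
    for rate in (4, 16)
]
results = run_bench_serving_multi(
    model="meta-llama/Llama-3.1-8B-Instruct",   # placeholder model
    base_url=base_url,
    other_server_args=[],
    benchmark_args=benchmark_args,
    need_warmup=True,
)
for args, res in results:
    print(args.request_rate, res)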
......@@ -71,8 +71,8 @@ nvcc_flags = [
"-std=c++17",
"-use_fast_math",
"-DFLASHINFER_ENABLE_F16",
"-Xcompiler",
"-w",
"-Xcompiler=-Wconversion",
"-Xcompiler=-fno-strict-aliasing",
]
nvcc_flags_fp8 = [
"-DFLASHINFER_ENABLE_FP8",
......