Unverified Commit dc1b8bcf authored by Ying Sheng's avatar Ying Sheng Committed by GitHub
Browse files

Format (#593)

parent 5a57b8ad
......@@ -92,4 +92,4 @@ if __name__ == "__main__":
print(ret)
speed = args.batch_size * max_new_tokens / latency
print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
\ No newline at end of file
print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
......@@ -307,8 +307,9 @@ def main(args: argparse.Namespace):
avg_per_output_token_latency = np.mean(
[latency / output_len for _, output_len, latency in REQUEST_LATENCY]
)
decoding_throughput = np.sum([
output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
decoding_throughput = (
np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
)
print(f"Total time: {benchmark_time:.2f} s")
print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
......
......@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
)
for i in redirect_indices:
target_idx = np.random.choice(min(i * 2 + 100, num_lines))
lines[
i
] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
lines[i] = (
f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
)
redirects[i] = target_idx
# Build links and find sources
......
......@@ -80,10 +80,12 @@ def main(args):
for i in range(test_df.shape[0]):
prompt_end = format_example(test_df, i, include_answer=False)
arguments.append({
"examples": few_shot_examples,
"question": prompt_end,
})
arguments.append(
{
"examples": few_shot_examples,
"question": prompt_end,
}
)
label = test_df.iloc[i, test_df.shape[1] - 1]
labels.append(label)
......@@ -134,7 +136,9 @@ def main(args):
pt = 0
for subject, num_qs in zip(subjects[: args.nsub], num_questions):
print(f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}")
print(
f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}"
)
pt += num_qs
assert pt == len(cors)
weighted_acc = np.mean(cors)
......
......@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
for i in range(len(prompts)):
assert len(input_ids[i]) > bench_args.cut_len
tmp_input_ids = input_ids[i][:bench_args.cut_len]
tmp_input_ids = input_ids[i][: bench_args.cut_len]
req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
req.prefix_indices = []
req.sampling_params = sampling_params
......@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
for i in range(len(reqs)):
req = reqs[i]
req.input_ids += input_ids[i][bench_args.cut_len:]
req.input_ids += input_ids[i][bench_args.cut_len :]
req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
i, :bench_args.cut_len
i, : bench_args.cut_len
]
return reqs
......@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
reqs=reqs,
req_to_token_pool=model_runner.req_to_token_pool,
token_to_kv_pool=model_runner.token_to_kv_pool,
tree_cache=None)
tree_cache=None,
)
batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
output = model_runner.forward(batch, ForwardMode.EXTEND)
next_token_ids, _ = batch.sample(output.next_token_logits)
......@@ -212,7 +213,9 @@ def latency_test(
# Load the model
model_runner, tokenizer = load_model(server_args, tp_rank)
print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
print(
f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
)
# Prepare inputs
reqs = prepare_synthetic_inputs(bench_args, tokenizer)
......@@ -232,7 +235,9 @@ def latency_test(
prefill_latency = time.time() - tic
tot_latency += prefill_latency
throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
rank_print(
f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
# Decode
for i in range(output_len):
......@@ -243,13 +248,24 @@ def latency_test(
latency = time.time() - tic
tot_latency += latency
throughput = bench_args.batch_size / latency
if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
if i < 5:
rank_print(
f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
)
avg_decode_latency = (tot_latency - prefill_latency) / output_len
avg_decode_throughput = bench_args.batch_size / avg_decode_latency
rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
rank_print(
f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
)
throughput = (
(bench_args.input_len + bench_args.output_len)
* bench_args.batch_size
/ tot_latency
)
rank_print(
f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
)
# Warm up
run_once(4)
......@@ -298,4 +314,4 @@ if __name__ == "__main__":
format="%(message)s",
)
main(server_args, bench_args)
\ No newline at end of file
main(server_args, bench_args)
......@@ -39,4 +39,5 @@ class GlobalConfig:
# This can improve the speed for large batch sizes during prefill.
self.layer_sync_threshold = 8192
global_config = GlobalConfig()
......@@ -185,8 +185,10 @@ class SglFunction:
batch_kwargs = [
{self.arg_names[i]: v for i, v in enumerate(arg_values)}
for arg_values in batch_kwargs
if isinstance(arg_values, (list, tuple)) and
len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
if isinstance(arg_values, (list, tuple))
and len(self.arg_names) - len(self.arg_defaults)
<= len(arg_values)
<= len(self.arg_names)
]
# Ensure to raise an exception if the number of arguments mismatch
if len(batch_kwargs) != num_programs:
......
......@@ -5,13 +5,14 @@ from pydantic import BaseModel
try:
from outlines.caching import cache as disk_cache
from outlines.fsm.guide import RegexGuide
from outlines.caching import disable_cache
from outlines.fsm.guide import RegexGuide
from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
from outlines.models.transformers import TransformerTokenizer
except ImportError as e:
print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
print(
f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
)
raise
try:
......
......@@ -264,7 +264,9 @@ class TiktokenTokenizer:
return self.tokenizer.decode_batch(batch)
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
ret = self.chat_template.render(
messages=messages, add_generation_prompt=add_generation_prompt
)
return self.encode(ret) if tokenize else ret
......@@ -297,5 +299,7 @@ class SentencePieceTokenizer:
return self.tokenizer.decode(batch)
def apply_chat_template(self, messages, tokenize, add_generation_prompt):
ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
return self.encode(ret) if tokenize else ret
\ No newline at end of file
ret = self.chat_template.render(
messages=messages, add_generation_prompt=add_generation_prompt
)
return self.encode(ret) if tokenize else ret
......@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
import torch
import triton
import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
......@@ -108,12 +107,16 @@ def fused_moe_kernel(
offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak)
a_ptrs = a_ptr + (
offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
)
off_experts = tl.load(expert_ids_ptr + pid_m)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn)
b_ptrs = (
b_ptr
+ off_experts * stride_be
+ (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
)
if use_fp8:
a_scale = tl.load(a_scale_ptr)
......@@ -129,13 +132,12 @@ def fused_moe_kernel(
for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
# Load the next block of A and B, generate a mask by checking the
# K dimension.
a = tl.load(a_ptrs,
mask=token_mask[:, None] &
(offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0)
b = tl.load(b_ptrs,
mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
other=0.0)
a = tl.load(
a_ptrs,
mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
other=0.0,
)
b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
# We accumulate along the K dimension.
if use_fp8:
accumulator = tl.dot(a, b, acc=accumulator)
......@@ -146,9 +148,7 @@ def fused_moe_kernel(
b_ptrs += BLOCK_SIZE_K * stride_bk
if MUL_ROUTED_WEIGHT:
moe_weight = tl.load(topk_weights_ptr + offs_token,
mask=token_mask,
other=0)
moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
accumulator = accumulator * moe_weight[:, None]
if use_fp8:
......@@ -158,15 +158,14 @@ def fused_moe_kernel(
# -----------------------------------------------------------
# Write back the block of the output
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
def moe_align_block_size(
topk_ids: torch.Tensor, block_size: int,
num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Aligns the token distribution across experts to be compatible with block
size for matrix multiplication.
......@@ -205,32 +204,38 @@ def moe_align_block_size(
by block_size for proper block matrix operations.
"""
max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
sorted_ids = torch.empty((max_num_tokens_padded, ),
dtype=torch.int32,
device=topk_ids.device)
sorted_ids = torch.empty(
(max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
)
sorted_ids.fill_(topk_ids.numel())
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
num_tokens_post_pad = torch.empty((1),
dtype=torch.int32,
device=topk_ids.device)
ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
expert_ids, num_tokens_post_pad)
expert_ids = torch.empty(
(max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
)
num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
ops.moe_align_block_size(
topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
)
return sorted_ids, expert_ids, num_tokens_post_pad
def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor, topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool, top_k: int,
config: Dict[str, Any], compute_type: tl.dtype,
use_fp8: bool) -> None:
def invoke_fused_moe_kernel(
A: torch.Tensor,
B: torch.Tensor,
C: torch.Tensor,
A_scale: Optional[torch.Tensor],
B_scale: Optional[torch.Tensor],
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
sorted_token_ids: torch.Tensor,
expert_ids: torch.Tensor,
num_tokens_post_padded: torch.Tensor,
mul_routed_weight: bool,
top_k: int,
config: Dict[str, Any],
compute_type: tl.dtype,
use_fp8: bool,
) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......@@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
A, A_scale = ops.scaled_fp8_quant(A, A_scale)
assert B_scale is not None
grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
grid = lambda META: (
triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
* triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
)
fused_moe_kernel[grid](
A,
......@@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
@functools.lru_cache
def get_moe_configs(E: int, N: int,
dtype: Optional[str]) -> Optional[Dict[int, Any]]:
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
......@@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int,
json_file_name = get_config_file_name(E, N, dtype)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
)
if os.path.exists(config_file_path):
with open(config_file_path) as f:
logger.info("Using configuration from %s for MoE layer.",
config_file_path)
logger.info("Using configuration from %s for MoE layer.", config_file_path)
# If a configuration has been found, return it
return {int(key): val for key, val in json.load(f).items()}
......@@ -319,35 +325,35 @@ def get_default_config(
) -> Dict[str, int]:
if dtype == "float8":
config = {
'BLOCK_SIZE_M': 128,
'BLOCK_SIZE_N': 256,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 32,
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 256,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 4
"num_stages": 4,
}
if M <= E:
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 128,
'BLOCK_SIZE_K': 128,
'GROUP_SIZE_M': 1,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 4
"num_stages": 4,
}
else:
config = {
'BLOCK_SIZE_M': 64,
'BLOCK_SIZE_N': 64,
'BLOCK_SIZE_K': 32,
'GROUP_SIZE_M': 8
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
}
if M <= E:
config = {
'BLOCK_SIZE_M': 16,
'BLOCK_SIZE_N': 32,
'BLOCK_SIZE_K': 64,
'GROUP_SIZE_M': 1
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
return config
......@@ -358,23 +364,17 @@ def fused_topk(
topk: int,
renormalize: bool,
):
assert hidden_states.shape[0] == gating_output.shape[0], (
"Number of tokens mismatch")
assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
M, _ = hidden_states.shape
topk_weights = torch.empty(M,
topk,
dtype=torch.float32,
device=hidden_states.device)
topk_ids = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
token_expert_indicies = torch.empty(M,
topk,
dtype=torch.int32,
device=hidden_states.device)
topk_weights = torch.empty(
M, topk, dtype=torch.float32, device=hidden_states.device
)
topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
token_expert_indicies = torch.empty(
M, topk, dtype=torch.int32, device=hidden_states.device
)
ops.topk_softmax(
topk_weights,
topk_ids,
......@@ -388,27 +388,27 @@ def fused_topk(
return topk_weights, topk_ids
def fused_experts(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None):
def fused_experts(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
override_config: Optional[Dict[str, Any]] = None,
use_fp8: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
):
# Check constraints.
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
assert w1.is_contiguous(), "Expert weights1 must be contiguous"
assert w2.is_contiguous(), "Expert weights2 must be contiguous"
assert hidden_states.dtype in [
torch.float32, torch.float16, torch.bfloat16
]
assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
M, _ = hidden_states.shape
E, N, _ = w1.shape
......@@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor,
config = override_config
else:
# First try to load optimal config from the file
configs = get_moe_configs(E, w2.shape[2],
"float8" if use_fp8 else None)
configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
if configs:
# If an optimal configuration map has been found, look up the
......@@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor,
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1.shape[2],
topk_ids.shape[1],
"float8" if use_fp8 else None)
intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype)
config = get_default_config(
M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
)
intermediate_cache1 = torch.empty(
(M, topk_ids.shape[1], N),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache2 = torch.empty(
(M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
intermediate_cache3 = torch.empty(
(M, topk_ids.shape[1], w2.shape[1]),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
topk_ids, config['BLOCK_SIZE_M'], E)
compute_type = (tl.bfloat16
if hidden_states.dtype == torch.bfloat16 else tl.float16)
invoke_fused_moe_kernel(hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8)
topk_ids, config["BLOCK_SIZE_M"], E
)
compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
invoke_fused_moe_kernel(
hidden_states,
w1,
intermediate_cache1,
a1_scale,
w1_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
topk_ids.shape[1],
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
invoke_fused_moe_kernel(intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8)
invoke_fused_moe_kernel(
intermediate_cache2,
w2,
intermediate_cache3,
a2_scale,
w2_scale,
topk_weights,
topk_ids,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
1,
config,
compute_type=compute_type,
use_fp8=use_fp8,
)
if inplace:
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
dim=1)
return torch.sum(
intermediate_cache3.view(*intermediate_cache3.shape),
dim=1,
out=hidden_states,
)
return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
def fused_moe(
......@@ -532,25 +542,28 @@ def fused_moe(
assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
if hasattr(ops, "topk_softmax"):
topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
renormalize)
topk_weights, topk_ids = fused_topk(
hidden_states, gating_output, topk, renormalize
)
else:
topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
renormalize)
return fused_experts(hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
override_config=override_config,
use_fp8=use_fp8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
topk_weights, topk_ids = fused_topk_v0_4_3(
hidden_states, gating_output, topk, renormalize
)
return fused_experts(
hidden_states,
w1,
w2,
topk_weights,
topk_ids,
inplace=inplace,
override_config=override_config,
use_fp8=use_fp8,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
)
def fused_topk_v0_4_3(
......@@ -560,6 +573,7 @@ def fused_topk_v0_4_3(
renormalize: bool,
):
import vllm._moe_C as moe_kernels
M, _ = hidden_states.shape
topk_weights = torch.empty(
......@@ -579,4 +593,4 @@ def fused_topk_v0_4_3(
if renormalize:
topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
return topk_weights, topk_ids
\ No newline at end of file
return topk_weights, topk_ids
"""Radix attention."""
import numpy as np
import torch
from torch import nn
......@@ -11,8 +12,13 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module):
def __init__(
self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int,
layer_id: int, logit_cap: int = -1
self,
num_heads: int,
head_dim: int,
scaling: float,
num_kv_heads: int,
layer_id: int,
logit_cap: int = -1,
):
super().__init__()
self.tp_q_head_num = num_heads
......@@ -112,6 +118,7 @@ class RadixAttention(nn.Module):
)
from flashinfer.cascade import merge_state
o, _ = merge_state(o1, s1, o2, s2)
if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
......
......@@ -99,4 +99,4 @@ def start_controller_process(
except Exception:
logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
finally:
kill_parent_process()
\ No newline at end of file
kill_parent_process()
......@@ -127,7 +127,7 @@ class InputMetadata:
num_qo_heads,
num_kv_heads,
head_dim,
1
1,
)
else:
self.flashinfer_decode_wrapper.end_forward()
......@@ -140,7 +140,7 @@ class InputMetadata:
head_dim,
1,
pos_encoding_mode="NONE",
data_type=self.token_to_kv_pool.kv_data[0].dtype
data_type=self.token_to_kv_pool.kv_data[0].dtype,
)
def init_extend_args(self):
......@@ -228,7 +228,7 @@ class InputMetadata:
ret.init_flashinfer_args(
model_runner.model_config.num_attention_heads // tp_size,
model_runner.model_config.get_num_kv_heads(tp_size),
model_runner.model_config.head_dim
model_runner.model_config.head_dim,
)
return ret
......@@ -269,7 +269,7 @@ class ModelRunner:
world_size=self.tp_size,
rank=self.tp_rank,
local_rank=self.gpu_id,
distributed_init_method=nccl_init_method
distributed_init_method=nccl_init_method,
)
initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
total_gpu_memory = get_available_gpu_memory(
......@@ -341,7 +341,13 @@ class ModelRunner:
)
head_dim = self.model_config.head_dim
head_num = self.model_config.get_num_kv_heads(self.tp_size)
cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
cell_size = (
head_num
* head_dim
* self.model_config.num_hidden_layers
* 2
* torch._utils._element_size(self.dtype)
)
rest_memory = available_gpu_memory - total_gpu_memory * (
1 - self.mem_fraction_static
)
......@@ -384,15 +390,16 @@ class ModelRunner:
def init_flash_infer(self):
if not global_server_args_dict.get("disable_flashinfer", False):
from flashinfer import (
BatchPrefillWithRaggedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchDecodeWithPagedKVCacheWrapper,
BatchPrefillWithPagedKVCacheWrapper,
BatchPrefillWithRaggedKVCacheWrapper,
)
from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
if not _grouped_size_compiled_for_decode_kernels(
self.model_config.num_attention_heads // self.tp_size,
self.model_config.get_num_kv_heads(self.tp_size)):
self.model_config.get_num_kv_heads(self.tp_size),
):
use_tensor_cores = True
else:
use_tensor_cores = False
......@@ -400,8 +407,8 @@ class ModelRunner:
workspace_buffers = torch.empty(
3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
)
self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
workspace_buffers[0], "NHD"
self.flashinfer_prefill_wrapper_ragged = (
BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
)
self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
workspace_buffers[1], "NHD"
......@@ -410,7 +417,9 @@ class ModelRunner:
workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
)
else:
self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
self.flashinfer_prefill_wrapper_ragged = (
self.flashinfer_prefill_wrapper_paged
) = None
self.flashinfer_decode_wrapper = None
@torch.inference_mode()
......
......@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
from sglang.srt.model_config import ModelConfig
from sglang.srt.server_args import ModelPortArgs, ServerArgs
from sglang.srt.utils import (
connect_rpyc_service,
get_int_token_logit_bias,
is_multimodal_model,
set_random_seed,
start_rpyc_service_process,
connect_rpyc_service,
suppress_other_loggers,
)
from sglang.utils import get_exception_traceback
......@@ -368,9 +368,11 @@ class ModelTpServer:
if (
req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
< available_size
and (req.extend_input_len + new_batch_input_tokens
<= self.max_prefill_tokens
or len(can_run_list) == 0)
and (
req.extend_input_len + new_batch_input_tokens
<= self.max_prefill_tokens
or len(can_run_list) == 0
)
):
delta = self.tree_cache.inc_lock_ref(req.last_node)
available_size += delta
......@@ -452,7 +454,9 @@ class ModelTpServer:
next_token_ids,
].tolist()
output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
output.normalized_prompt_logprobs = (
output.normalized_prompt_logprobs.tolist()
)
next_token_ids = next_token_ids.tolist()
else:
......@@ -582,7 +586,9 @@ class ModelTpServer:
req.check_finished()
if req.return_logprob:
req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
req.decode_token_logprobs.append(
(next_token_logprobs[i], next_token_id)
)
if req.top_logprobs_num > 0:
req.decode_top_logprobs.append(output.decode_top_logprobs[i])
......@@ -759,16 +765,27 @@ class ModelTpClient:
with ThreadPoolExecutor(self.tp_size) as executor:
# Launch model processes
if server_args.nnodes == 1:
self.procs = list(executor.map(
lambda args: start_rpyc_service_process(*args),
[(ModelTpService, p) for p in model_port_args.model_tp_ports],
))
self.procs = list(
executor.map(
lambda args: start_rpyc_service_process(*args),
[
(ModelTpService, p)
for p in model_port_args.model_tp_ports
],
)
)
addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
else:
addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
self.model_services = list(executor.map(
lambda args: connect_rpyc_service(*args), addrs))
addrs = [
(ip, port)
for ip, port in zip(
model_port_args.model_tp_ips, model_port_args.model_tp_ports
)
]
self.model_services = list(
executor.map(lambda args: connect_rpyc_service(*args), addrs)
)
# Init model
def init_model(i):
......
......@@ -334,15 +334,15 @@ class TokenizerManager:
ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
)
if top_logprobs_num > 0:
ret["meta_info"][
"prefill_top_logprobs"
] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["prefill_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
)
)
ret["meta_info"][
"decode_top_logprobs"
] = self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
ret["meta_info"]["decode_top_logprobs"] = (
self.detokenize_top_logprobs_tokens(
ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
)
)
return ret
......
......@@ -5,19 +5,23 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
from torch import nn
from transformers import Gemma2Config
from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_tensor_model_parallel_world_size
# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.activation import GeluAndMul
# from vllm.model_executor.layers.layernorm import GemmaRMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.linear import (
MergedColumnParallelLinear,
QKVParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding)
from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
......@@ -26,8 +30,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.managers.controller.model_runner import InputMetadata
# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.custom_op import CustomOp
class GemmaRMSNorm(CustomOp):
"""RMS normalization for Gemma.
......@@ -76,13 +78,19 @@ class GemmaRMSNorm(CustomOp):
# FIXME: temporary solution, remove after next vllm release
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
class GemmaRotaryEmbedding(RotaryEmbedding):
def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
# https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
inv_freq = 1.0 / (base**(
torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
self.rotary_dim))
inv_freq = 1.0 / (
base
** (
torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
/ self.rotary_dim
)
)
return inv_freq
......@@ -98,18 +106,17 @@ class Gemma2MLP(nn.Module):
) -> None:
super().__init__()
self.gate_up_proj = MergedColumnParallelLinear(
hidden_size, [intermediate_size] * 2,
bias=False,
quant_config=quant_config)
self.down_proj = RowParallelLinear(intermediate_size,
hidden_size,
bias=False,
quant_config=quant_config)
hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
)
self.down_proj = RowParallelLinear(
intermediate_size, hidden_size, bias=False, quant_config=quant_config
)
if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
raise ValueError(
"Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
"function. Please set `hidden_act` and `hidden_activation` to "
"`gelu_pytorch_tanh`.")
"`gelu_pytorch_tanh`."
)
self.act_fn = GeluAndMul(approximate="tanh")
def forward(self, x: torch.Tensor) -> torch.Tensor:
......@@ -121,17 +128,19 @@ class Gemma2MLP(nn.Module):
class Gemma2Attention(nn.Module):
def __init__(self,
layer_idx: int,
config: Gemma2Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int,
rope_theta: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None) -> None:
def __init__(
self,
layer_idx: int,
config: Gemma2Config,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
head_dim: int,
max_position_embeddings: int,
rope_theta: float,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
super().__init__()
self.layer_idx = layer_idx
self.config = config
......@@ -183,15 +192,16 @@ class Gemma2Attention(nn.Module):
# from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
# odd layer, vLLM currently ignores it and uses global attention for
# all layers.
use_sliding_window = (layer_idx % 2 == 1
and config.sliding_window is not None)
use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
del use_sliding_window # Unused.
self.attn = RadixAttention(self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_idx,
logit_cap=self.config.attn_logit_softcapping)
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_idx,
logit_cap=self.config.attn_logit_softcapping,
)
def forward(
self,
......@@ -238,14 +248,16 @@ class Gemma2DecoderLayer(nn.Module):
hidden_activation=config.hidden_activation,
quant_config=quant_config,
)
self.input_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.pre_feedforward_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.post_feedforward_layernorm = GemmaRMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
def forward(
self,
......@@ -258,8 +270,7 @@ class Gemma2DecoderLayer(nn.Module):
residual = hidden_states
hidden_states = self.input_layernorm(hidden_states)
else:
hidden_states, residual = self.input_layernorm(
hidden_states, residual)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
......@@ -268,7 +279,8 @@ class Gemma2DecoderLayer(nn.Module):
hidden_states = self.post_attention_layernorm(hidden_states)
hidden_states, residual = self.pre_feedforward_layernorm(
hidden_states, residual)
hidden_states, residual
)
hidden_states = self.mlp(hidden_states)
hidden_states = self.post_feedforward_layernorm(hidden_states)
return hidden_states, residual
......@@ -289,10 +301,12 @@ class Gemma2Model(nn.Module):
config.vocab_size,
config.hidden_size,
)
self.layers = nn.ModuleList([
Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
])
self.layers = nn.ModuleList(
[
Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
for layer_idx in range(config.num_hidden_layers)
]
)
self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
# Normalize the embedding by sqrt(hidden_size)
......@@ -392,7 +406,7 @@ class Gemma2ForCausalLM(nn.Module):
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
for (param_name, shard_name, shard_id) in stacked_params_mapping:
for param_name, shard_name, shard_id in stacked_params_mapping:
if shard_name not in name:
continue
name = name.replace(shard_name, param_name)
......@@ -412,8 +426,7 @@ class Gemma2ForCausalLM(nn.Module):
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
......@@ -421,7 +434,8 @@ class Gemma2ForCausalLM(nn.Module):
if unloaded_params:
raise RuntimeError(
"Some weights are not initialized from checkpoints: "
f"{unloaded_params}")
f"{unloaded_params}"
)
EntryClass = Gemma2ForCausalLM
\ No newline at end of file
EntryClass = Gemma2ForCausalLM
......@@ -5,14 +5,12 @@ import tqdm
from torch import nn
from transformers import LlamaConfig
from vllm.config import CacheConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
)
from vllm.distributed import get_tensor_model_parallel_rank
from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.layers.logits_processor import LogitProcessorOutput
from sglang.srt.managers.controller.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaModel
......@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
self.quant_config = quant_config
self.model = LlamaModel(config, quant_config=quant_config)
self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
self.classification_head = nn.Linear(
config.hidden_size, config.classification_out_size
)
self.eos_token_id = config.eos_token_id
def forward(
......@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):
if scores.shape[0] != input_metadata.batch_size:
print("Warning: the EOS tokens are missing in some sentences.")
scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
scores = torch.ones(
(input_metadata.batch_size, self.config.classification_out_size)
).to(input_ids.device)
return LogitProcessorOutput(
next_token_logits=scores,
......@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
EntryClass = LlamaForClassification
\ No newline at end of file
EntryClass = LlamaForClassification
......@@ -51,13 +51,12 @@ from sglang.srt.utils import (
allocate_init_ports,
assert_pkg_version,
enable_show_time_cost,
send_addrs_to_rank_0,
receive_addrs,
send_addrs_to_rank_0,
start_rpyc_service_process,
)
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
......@@ -152,9 +151,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
if server_args.disable_disk_cache:
disable_cache()
if not server_args.disable_flashinfer:
assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.")
assert_pkg_version(
"flashinfer",
"0.0.8",
"Please uninstall the old version and "
"reinstall the latest version by following the instructions "
"at https://docs.flashinfer.ai/installation.html.",
)
if server_args.chat_template:
# TODO: replace this with huggingface transformers template
load_chat_template_for_openai_api(server_args.chat_template)
......@@ -176,7 +179,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
ModelPortArgs(
nccl_port=ports[3 + i * (tp_size_local + 1)],
model_tp_ips=[None] * tp_size_local,
model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
model_tp_ports=ports[
3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
],
)
)
port_args = PortArgs(
......@@ -194,9 +199,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
else:
receive_addrs(model_port_args[0], server_args)
for i in range(tp_size_local):
start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
start_rpyc_service_process(
ModelTpService, model_port_args[0].model_tp_ports[i]
)
if server_args.node_rank != 0:
logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
logger.info(
f"[node_rank={server_args.node_rank}]: Listen for connections..."
)
while True:
pass
......
......@@ -137,17 +137,16 @@ class ServerArgs:
"--dtype",
type=str,
default=ServerArgs.dtype,
choices=[
"auto", "half", "float16", "bfloat16", "float", "float32"
],
help='Data type for model weights and activations.\n\n'
choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
help="Data type for model weights and activations.\n\n"
'* "auto" will use FP16 precision for FP32 and FP16 models, and '
'BF16 precision for BF16 models.\n'
"BF16 precision for BF16 models.\n"
'* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.')
'* "float32" for FP32 precision.',
)
parser.add_argument(
"--trust-remote-code",
action="store_true",
......@@ -271,19 +270,12 @@ class ServerArgs:
parser.add_argument(
"--nccl-init-addr",
type=str,
help="The nccl init address of multi-node server."
help="The nccl init address of multi-node server.",
)
parser.add_argument(
"--nnodes",
type=int,
default=1,
help="The number of nodes."
)
parser.add_argument(
"--node-rank",
type=int,
help="The node rank."
"--nnodes", type=int, default=1, help="The number of nodes."
)
parser.add_argument("--node-rank", type=int, help="The node rank.")
# Optimization/debug options
parser.add_argument(
......
......@@ -432,13 +432,12 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
raise Exception(
f"{pkg} is installed with version {installed_version}, which "
f"is less than the minimum required version {min_version}. " +
message
f"is less than the minimum required version {min_version}. " + message
)
except PackageNotFoundError:
raise Exception(
f"{pkg} with minimum required version {min_version} is not installed. " +
message
f"{pkg} with minimum required version {min_version} is not installed. "
+ message
)
......@@ -474,24 +473,40 @@ def monkey_patch_vllm_dummy_weight_loader():
"""
from vllm.model_executor.model_loader.loader import (
ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig,
ParallelConfig, SchedulerConfig, CacheConfig, nn,
set_default_torch_dtype, _initialize_model, initialize_dummy_weights,
DummyModelLoader
CacheConfig,
DeviceConfig,
DummyModelLoader,
LoRAConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
VisionLanguageConfig,
_initialize_model,
initialize_dummy_weights,
nn,
set_default_torch_dtype,
)
def load_model(self, *, model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig) -> nn.Module:
def load_model(
self,
*,
model_config: ModelConfig,
device_config: DeviceConfig,
lora_config: Optional[LoRAConfig],
vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config,
lora_config, vision_language_config,
cache_config)
model = _initialize_model(
model_config,
self.load_config,
lora_config,
vision_language_config,
cache_config,
)
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
......@@ -541,7 +556,7 @@ def get_ip_address(ifname):
ip_address = fcntl.ioctl(
s.fileno(),
0x8915, # SIOCGIFADDR
struct.pack('256s', bytes(ifname[:15], 'utf-8'))
struct.pack("256s", bytes(ifname[:15], "utf-8")),
)[20:24]
return socket.inet_ntoa(ip_address)
......@@ -550,44 +565,66 @@ def send_addrs_to_rank_0(model_port_args, server_args):
assert server_args.node_rank != 0 and server_args.dp_size == 1
import torch.distributed as dist
ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
)
ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
ip_addr = [int(x) for x in ip_addr.split(".")]
addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int)
addrs_tensor = torch.tensor(
ip_addr + model_port_args.model_tp_ports, dtype=torch.int
)
init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=server_args.node_rank,
world_size=server_args.nnodes,
)
dist.send(addrs_tensor, dst=0)
print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}")
print(
f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
)
dist.barrier()
dist.destroy_process_group()
dist.destroy_process_group()
def receive_addrs(model_port_args, server_args):
assert server_args.node_rank == 0 and server_args.dp_size == 1
import torch.distributed as dist
ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0"))
ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
)
ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes)
dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=server_args.node_rank,
world_size=server_args.nnodes,
)
for src_rank in range(1, server_args.nnodes):
tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
dist.recv(tensor, src=src_rank)
ip = ".".join([str(x) for x in tensor[:4].tolist()])
ports = tensor[4:].tolist()
model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports
model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports
model_port_args.model_tp_ips[
num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
] = [ip] * num_tp_ports
model_port_args.model_tp_ports[
num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
] = ports
print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
dist.barrier()
dist.destroy_process_group()
dist.destroy_process_group()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment