Unverified Commit dc1b8bcf authored by Ying Sheng, committed by GitHub

Format (#593)

parent 5a57b8ad
@@ -92,4 +92,4 @@ if __name__ == "__main__":
     print(ret)
 
     speed = args.batch_size * max_new_tokens / latency
-    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
\ No newline at end of file
+    print(f"latency: {latency:.2f} s, speed: {speed:.2f} token/s")
@@ -307,8 +307,9 @@ def main(args: argparse.Namespace):
     avg_per_output_token_latency = np.mean(
         [latency / output_len for _, output_len, latency in REQUEST_LATENCY]
     )
-    decoding_throughput = np.sum([
-        output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
+    decoding_throughput = (
+        np.sum([output_len for _, output_len, _ in REQUEST_LATENCY]) / benchmark_time
+    )
 
     print(f"Total time: {benchmark_time:.2f} s")
     print(f"Request throughput: {args.num_prompts / benchmark_time:.2f} requests/s")
...
@@ -48,9 +48,9 @@ def generate_lines(random_words, num_lines, redirect_ratio):
     )
     for i in redirect_indices:
         target_idx = np.random.choice(min(i * 2 + 100, num_lines))
-        lines[
-            i
-        ] = f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
+        lines[i] = (
+            f"Line {indices[i]}: The REGISTER_CONTENT is the same as Line {indices[target_idx]}."
+        )
         redirects[i] = target_idx
 
     # Build links and find sources
...
@@ -80,10 +80,12 @@ def main(args):
     for i in range(test_df.shape[0]):
         prompt_end = format_example(test_df, i, include_answer=False)
-        arguments.append({
-            "examples": few_shot_examples,
-            "question": prompt_end,
-        })
+        arguments.append(
+            {
+                "examples": few_shot_examples,
+                "question": prompt_end,
+            }
+        )
 
         label = test_df.iloc[i, test_df.shape[1] - 1]
         labels.append(label)
@@ -134,7 +136,9 @@ def main(args):
     pt = 0
     for subject, num_qs in zip(subjects[: args.nsub], num_questions):
-        print(f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}")
+        print(
+            f"subject: {subject}, #q:{num_qs}, acc: {np.mean(cors[pt: pt + num_qs]):.3f}"
+        )
         pt += num_qs
 
     assert pt == len(cors)
     weighted_acc = np.mean(cors)
...
@@ -108,7 +108,7 @@ def prepare_inputs(bench_args, tokenizer):
     for i in range(len(prompts)):
         assert len(input_ids[i]) > bench_args.cut_len
-        tmp_input_ids = input_ids[i][:bench_args.cut_len]
+        tmp_input_ids = input_ids[i][: bench_args.cut_len]
         req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
         req.prefix_indices = []
         req.sampling_params = sampling_params
@@ -121,9 +121,9 @@ def prepare_inputs(bench_args, tokenizer):
 def prepare_extend_inputs(bench_args, input_ids, reqs, model_runner):
     for i in range(len(reqs)):
         req = reqs[i]
-        req.input_ids += input_ids[i][bench_args.cut_len:]
+        req.input_ids += input_ids[i][bench_args.cut_len :]
         req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
-            i, :bench_args.cut_len
+            i, : bench_args.cut_len
         ]
     return reqs
@@ -151,7 +151,8 @@ def extend(reqs, model_runner):
         reqs=reqs,
         req_to_token_pool=model_runner.req_to_token_pool,
         token_to_kv_pool=model_runner.token_to_kv_pool,
-        tree_cache=None)
+        tree_cache=None,
+    )
     batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
     output = model_runner.forward(batch, ForwardMode.EXTEND)
     next_token_ids, _ = batch.sample(output.next_token_logits)
@@ -212,7 +213,9 @@ def latency_test(
     # Load the model
     model_runner, tokenizer = load_model(server_args, tp_rank)
-    print(f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}")
+    print(
+        f"max_batch_size={model_runner.max_total_num_tokens // (bench_args.input_len + bench_args.output_len)}"
+    )
 
     # Prepare inputs
     reqs = prepare_synthetic_inputs(bench_args, tokenizer)
@@ -232,7 +235,9 @@ def latency_test(
         prefill_latency = time.time() - tic
         tot_latency += prefill_latency
         throughput = bench_args.input_len * bench_args.batch_size / prefill_latency
-        rank_print(f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+        rank_print(
+            f"Prefill. latency: {prefill_latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+        )
 
         # Decode
         for i in range(output_len):
@@ -243,13 +248,24 @@ def latency_test(
             latency = time.time() - tic
             tot_latency += latency
             throughput = bench_args.batch_size / latency
-            if i < 5: rank_print(f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s")
+            if i < 5:
+                rank_print(
+                    f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                )
 
         avg_decode_latency = (tot_latency - prefill_latency) / output_len
         avg_decode_throughput = bench_args.batch_size / avg_decode_latency
-        rank_print(f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s")
+        rank_print(
+            f"Decode. avg latency: {avg_decode_latency:6.5f} s, avg throughput: {avg_decode_throughput:9.2f} token/s"
+        )
 
-        throughput = (bench_args.input_len + bench_args.output_len) * bench_args.batch_size / tot_latency
-        rank_print(f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s")
+        throughput = (
+            (bench_args.input_len + bench_args.output_len)
+            * bench_args.batch_size
+            / tot_latency
+        )
+        rank_print(
+            f"Total. latency: {tot_latency:6.3f} s, throughput: {throughput:9.2f} token/s"
+        )
 
     # Warm up
     run_once(4)
@@ -298,4 +314,4 @@ if __name__ == "__main__":
         format="%(message)s",
     )
 
-    main(server_args, bench_args)
\ No newline at end of file
+    main(server_args, bench_args)
@@ -39,4 +39,5 @@ class GlobalConfig:
         # This can improve the speed for large batch sizes during prefill.
         self.layer_sync_threshold = 8192
 
+
 global_config = GlobalConfig()
@@ -185,8 +185,10 @@ class SglFunction:
             batch_kwargs = [
                 {self.arg_names[i]: v for i, v in enumerate(arg_values)}
                 for arg_values in batch_kwargs
-                if isinstance(arg_values, (list, tuple)) and
-                len(self.arg_names) - len(self.arg_defaults) <= len(arg_values) <= len(self.arg_names)
+                if isinstance(arg_values, (list, tuple))
+                and len(self.arg_names) - len(self.arg_defaults)
+                <= len(arg_values)
+                <= len(self.arg_names)
             ]
             # Ensure to raise an exception if the number of arguments mismatch
             if len(batch_kwargs) != num_programs:
...
@@ -5,13 +5,14 @@ from pydantic import BaseModel
 try:
     from outlines.caching import cache as disk_cache
-    from outlines.fsm.guide import RegexGuide
     from outlines.caching import disable_cache
+    from outlines.fsm.guide import RegexGuide
     from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
     from outlines.models.transformers import TransformerTokenizer
 except ImportError as e:
-    print(f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n')
+    print(
+        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
+    )
     raise
 
 try:
...
@@ -264,7 +264,9 @@ class TiktokenTokenizer:
         return self.tokenizer.decode_batch(batch)
 
     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
         return self.encode(ret) if tokenize else ret
@@ -297,5 +299,7 @@ class SentencePieceTokenizer:
         return self.tokenizer.decode(batch)
 
     def apply_chat_template(self, messages, tokenize, add_generation_prompt):
-        ret = self.chat_template.render(messages=messages, add_generation_prompt=add_generation_prompt)
-        return self.encode(ret) if tokenize else ret
\ No newline at end of file
+        ret = self.chat_template.render(
+            messages=messages, add_generation_prompt=add_generation_prompt
+        )
+        return self.encode(ret) if tokenize else ret
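For orientation, the `apply_chat_template(messages, tokenize, add_generation_prompt)` methods reformatted above follow the same calling convention as the Hugging Face tokenizer interface. A minimal usage sketch, not part of this commit (the model name is purely illustrative):

# Sketch only: mirrors the calling convention of the wrappers above via
# transformers' AutoTokenizer.apply_chat_template.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")  # example model
messages = [{"role": "user", "content": "Hello!"}]
# tokenize=False returns the rendered prompt string; tokenize=True returns token ids,
# matching the `return self.encode(ret) if tokenize else ret` branch above.
prompt = tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
input_ids = tok.apply_chat_template(messages, tokenize=True, add_generation_prompt=True)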
@@ -9,7 +9,6 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
@@ -108,12 +107,16 @@ def fused_moe_kernel(
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
 
-    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
-                      offs_k[None, :] * stride_ak)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
 
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
-                                                offs_bn[None, :] * stride_bn)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
 
     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
@@ -129,13 +132,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] &
-                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
-        b = tl.load(b_ptrs,
-                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
-                    other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -146,9 +148,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
     if use_fp8:
@@ -158,15 +158,14 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
 def moe_align_block_size(
-        topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -205,32 +204,38 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    num_tokens_post_pad = torch.empty((1),
-                                      dtype=torch.int32,
-                                      device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            A_scale: Optional[torch.Tensor],
-                            B_scale: Optional[torch.Tensor],
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any], compute_type: tl.dtype,
-                            use_fp8: bool) -> None:
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
@@ -241,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
 
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )
 
     fused_moe_kernel[grid](
         A,
@@ -280,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 @functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
@@ -296,11 +302,11 @@ def get_moe_configs(E: int, N: int,
     json_file_name = get_config_file_name(E, N, dtype)
 
     config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
     if os.path.exists(config_file_path):
         with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
             # If a configuration has been found, return it
             return {int(key): val for key, val in json.load(f).items()}
@@ -319,35 +325,35 @@ def get_default_config(
 ) -> Dict[str, int]:
     if dtype == "float8":
         config = {
-            'BLOCK_SIZE_M': 128,
-            'BLOCK_SIZE_N': 256,
-            'BLOCK_SIZE_K': 128,
-            'GROUP_SIZE_M': 32,
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 4
+            "num_stages": 4,
         }
         if M <= E:
             config = {
-                'BLOCK_SIZE_M': 64,
-                'BLOCK_SIZE_N': 128,
-                'BLOCK_SIZE_K': 128,
-                'GROUP_SIZE_M': 1,
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 4
+                "num_stages": 4,
             }
     else:
         config = {
-            'BLOCK_SIZE_M': 64,
-            'BLOCK_SIZE_N': 64,
-            'BLOCK_SIZE_K': 32,
-            'GROUP_SIZE_M': 8
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
         }
         if M <= E:
            config = {
-                'BLOCK_SIZE_M': 16,
-                'BLOCK_SIZE_N': 32,
-                'BLOCK_SIZE_K': 64,
-                'GROUP_SIZE_M': 1
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
             }
     return config
@@ -358,23 +364,17 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
 
-    topk_weights = torch.empty(M,
-                               topk,
-                               dtype=torch.float32,
-                               device=hidden_states.device)
-    topk_ids = torch.empty(M,
-                           topk,
-                           dtype=torch.int32,
-                           device=hidden_states.device)
-    token_expert_indicies = torch.empty(M,
-                                        topk,
-                                        dtype=torch.int32,
-                                        device=hidden_states.device)
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
     ops.topk_softmax(
         topk_weights,
         topk_ids,
@@ -388,27 +388,27 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
-def fused_experts(hidden_states: torch.Tensor,
-                  w1: torch.Tensor,
-                  w2: torch.Tensor,
-                  topk_weights: torch.Tensor,
-                  topk_ids: torch.Tensor,
-                  inplace: bool = False,
-                  override_config: Optional[Dict[str, Any]] = None,
-                  use_fp8: bool = False,
-                  w1_scale: Optional[torch.Tensor] = None,
-                  w2_scale: Optional[torch.Tensor] = None,
-                  a1_scale: Optional[torch.Tensor] = None,
-                  a2_scale: Optional[torch.Tensor] = None):
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
     # Check constraints.
     assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
     assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
 
     M, _ = hidden_states.shape
     E, N, _ = w1.shape
@@ -417,8 +417,7 @@ def fused_experts(hidden_states: torch.Tensor,
         config = override_config
     else:
         # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
 
         if configs:
             # If an optimal configuration map has been found, look up the
@@ -426,65 +425,76 @@ def fused_experts(hidden_states: torch.Tensor,
             config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
         else:
             # Else use the default config
-            config = get_default_config(M, E, N, w1.shape[2],
-                                        topk_ids.shape[1],
-                                        "float8" if use_fp8 else None)
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )
 
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
 
     sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
 
-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
 
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
 
     if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
 
 
 def fused_moe(
@@ -532,25 +542,28 @@ def fused_moe(
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
 
     if hasattr(ops, "topk_softmax"):
-        topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk,
-                                            renormalize)
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
+        )
     else:
-        topk_weights, topk_ids = fused_topk_v0_4_3(hidden_states, gating_output, topk,
-                                                   renormalize)
+        topk_weights, topk_ids = fused_topk_v0_4_3(
+            hidden_states, gating_output, topk, renormalize
+        )
 
-    return fused_experts(hidden_states,
-                         w1,
-                         w2,
-                         topk_weights,
-                         topk_ids,
-                         inplace=inplace,
-                         override_config=override_config,
-                         use_fp8=use_fp8,
-                         w1_scale=w1_scale,
-                         w2_scale=w2_scale,
-                         a1_scale=a1_scale,
-                         a2_scale=a2_scale)
+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        override_config=override_config,
+        use_fp8=use_fp8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
 
 
 def fused_topk_v0_4_3(
@@ -560,6 +573,7 @@ def fused_topk_v0_4_3(
     renormalize: bool,
 ):
     import vllm._moe_C as moe_kernels
+
     M, _ = hidden_states.shape
     topk_weights = torch.empty(
@@ -579,4 +593,4 @@ def fused_topk_v0_4_3(
     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    return topk_weights, topk_ids
\ No newline at end of file
+    return topk_weights, topk_ids
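As a rough shape sketch for the entry points reformatted above, inferred from the asserts and intermediate-cache shapes visible in this file (the import path, tensor sizes, and a CUDA device with vLLM's custom ops are assumptions, not part of the commit):

import torch

# Assumed module path for the file reformatted above.
from sglang.srt.layers.fused_moe import fused_experts, fused_topk

M, hidden, E, topk = 4, 512, 8, 2   # tokens, hidden size, num experts, experts per token
N = 1024                            # w1's fused gate/up output size; gelu_and_mul halves it

hidden_states = torch.randn(M, hidden, dtype=torch.float16, device="cuda")
w1 = torch.randn(E, N, hidden, dtype=torch.float16, device="cuda")       # w1.shape[2] == hidden
w2 = torch.randn(E, hidden, N // 2, dtype=torch.float16, device="cuda")  # maps N // 2 back to hidden
gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")

topk_weights, topk_ids = fused_topk(hidden_states, gating_output, topk, renormalize=True)
out = fused_experts(hidden_states, w1, w2, topk_weights, topk_ids)       # -> (M, hidden)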
"""Radix attention.""" """Radix attention."""
import numpy as np import numpy as np
import torch import torch
from torch import nn from torch import nn
...@@ -11,8 +12,13 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada ...@@ -11,8 +12,13 @@ from sglang.srt.managers.controller.model_runner import ForwardMode, InputMetada
class RadixAttention(nn.Module): class RadixAttention(nn.Module):
def __init__( def __init__(
self, num_heads: int, head_dim: int, scaling: float, num_kv_heads: int, self,
layer_id: int, logit_cap: int = -1 num_heads: int,
head_dim: int,
scaling: float,
num_kv_heads: int,
layer_id: int,
logit_cap: int = -1,
): ):
super().__init__() super().__init__()
self.tp_q_head_num = num_heads self.tp_q_head_num = num_heads
...@@ -112,6 +118,7 @@ class RadixAttention(nn.Module): ...@@ -112,6 +118,7 @@ class RadixAttention(nn.Module):
) )
from flashinfer.cascade import merge_state from flashinfer.cascade import merge_state
o, _ = merge_state(o1, s1, o2, s2) o, _ = merge_state(o1, s1, o2, s2)
if input_metadata.total_num_tokens >= global_config.layer_sync_threshold: if input_metadata.total_num_tokens >= global_config.layer_sync_threshold:
......
@@ -99,4 +99,4 @@ def start_controller_process(
     except Exception:
         logger.error("Exception in ControllerSingle:\n" + get_exception_traceback())
     finally:
-        kill_parent_process()
\ No newline at end of file
+        kill_parent_process()
@@ -127,7 +127,7 @@ class InputMetadata:
                 num_qo_heads,
                 num_kv_heads,
                 head_dim,
-                1
+                1,
             )
         else:
             self.flashinfer_decode_wrapper.end_forward()
@@ -140,7 +140,7 @@ class InputMetadata:
                 head_dim,
                 1,
                 pos_encoding_mode="NONE",
-                data_type=self.token_to_kv_pool.kv_data[0].dtype
+                data_type=self.token_to_kv_pool.kv_data[0].dtype,
             )
 
     def init_extend_args(self):
@@ -228,7 +228,7 @@ class InputMetadata:
             ret.init_flashinfer_args(
                 model_runner.model_config.num_attention_heads // tp_size,
                 model_runner.model_config.get_num_kv_heads(tp_size),
-                model_runner.model_config.head_dim
+                model_runner.model_config.head_dim,
             )
 
         return ret
@@ -269,7 +269,7 @@ class ModelRunner:
             world_size=self.tp_size,
             rank=self.tp_rank,
             local_rank=self.gpu_id,
-            distributed_init_method=nccl_init_method
+            distributed_init_method=nccl_init_method,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         total_gpu_memory = get_available_gpu_memory(
@@ -341,7 +341,13 @@ class ModelRunner:
         )
         head_dim = self.model_config.head_dim
         head_num = self.model_config.get_num_kv_heads(self.tp_size)
-        cell_size = head_num * head_dim * self.model_config.num_hidden_layers * 2 * torch._utils._element_size(self.dtype)
+        cell_size = (
+            head_num
+            * head_dim
+            * self.model_config.num_hidden_layers
+            * 2
+            * torch._utils._element_size(self.dtype)
+        )
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
@@ -384,15 +390,16 @@ class ModelRunner:
     def init_flash_infer(self):
         if not global_server_args_dict.get("disable_flashinfer", False):
             from flashinfer import (
-                BatchPrefillWithRaggedKVCacheWrapper,
-                BatchPrefillWithPagedKVCacheWrapper,
                 BatchDecodeWithPagedKVCacheWrapper,
+                BatchPrefillWithPagedKVCacheWrapper,
+                BatchPrefillWithRaggedKVCacheWrapper,
             )
             from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
             if not _grouped_size_compiled_for_decode_kernels(
                 self.model_config.num_attention_heads // self.tp_size,
-                self.model_config.get_num_kv_heads(self.tp_size)):
+                self.model_config.get_num_kv_heads(self.tp_size),
+            ):
                 use_tensor_cores = True
             else:
                 use_tensor_cores = False
@@ -400,8 +407,8 @@ class ModelRunner:
             workspace_buffers = torch.empty(
                 3, 96 * 1024 * 1024, dtype=torch.uint8, device="cuda"
             )
-            self.flashinfer_prefill_wrapper_ragged = BatchPrefillWithRaggedKVCacheWrapper(
-                workspace_buffers[0], "NHD"
+            self.flashinfer_prefill_wrapper_ragged = (
+                BatchPrefillWithRaggedKVCacheWrapper(workspace_buffers[0], "NHD")
             )
             self.flashinfer_prefill_wrapper_paged = BatchPrefillWithPagedKVCacheWrapper(
                 workspace_buffers[1], "NHD"
@@ -410,7 +417,9 @@ class ModelRunner:
                 workspace_buffers[2], "NHD", use_tensor_cores=use_tensor_cores
             )
         else:
-            self.flashinfer_prefill_wrapper_ragged = self.flashinfer_prefill_wrapper_paged = None
+            self.flashinfer_prefill_wrapper_ragged = (
+                self.flashinfer_prefill_wrapper_paged
+            ) = None
             self.flashinfer_decode_wrapper = None
 
     @torch.inference_mode()
...
@@ -34,11 +34,11 @@ from sglang.srt.managers.io_struct import (
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import ModelPortArgs, ServerArgs
 from sglang.srt.utils import (
+    connect_rpyc_service,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
     start_rpyc_service_process,
-    connect_rpyc_service,
     suppress_other_loggers,
 )
 from sglang.utils import get_exception_traceback
@@ -368,9 +368,11 @@ class ModelTpServer:
             if (
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
-                and (req.extend_input_len + new_batch_input_tokens
-                     <= self.max_prefill_tokens
-                     or len(can_run_list) == 0)
+                and (
+                    req.extend_input_len + new_batch_input_tokens
+                    <= self.max_prefill_tokens
+                    or len(can_run_list) == 0
+                )
             ):
                 delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta
@@ -452,7 +454,9 @@ class ModelTpServer:
                     next_token_ids,
                 ].tolist()
                 output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
-                output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()
+                output.normalized_prompt_logprobs = (
+                    output.normalized_prompt_logprobs.tolist()
+                )
 
             next_token_ids = next_token_ids.tolist()
         else:
@@ -582,7 +586,9 @@ class ModelTpServer:
             req.check_finished()
 
             if req.return_logprob:
-                req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                req.decode_token_logprobs.append(
+                    (next_token_logprobs[i], next_token_id)
+                )
                 if req.top_logprobs_num > 0:
                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])
@@ -759,16 +765,27 @@ class ModelTpClient:
             with ThreadPoolExecutor(self.tp_size) as executor:
                 # Launch model processes
                 if server_args.nnodes == 1:
-                    self.procs = list(executor.map(
-                        lambda args: start_rpyc_service_process(*args),
-                        [(ModelTpService, p) for p in model_port_args.model_tp_ports],
-                    ))
+                    self.procs = list(
+                        executor.map(
+                            lambda args: start_rpyc_service_process(*args),
+                            [
+                                (ModelTpService, p)
+                                for p in model_port_args.model_tp_ports
+                            ],
+                        )
+                    )
                     addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
                 else:
-                    addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+                    addrs = [
+                        (ip, port)
+                        for ip, port in zip(
+                            model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                        )
+                    ]
 
-                self.model_services = list(executor.map(
-                    lambda args: connect_rpyc_service(*args), addrs))
+                self.model_services = list(
+                    executor.map(lambda args: connect_rpyc_service(*args), addrs)
+                )
 
                 # Init model
                 def init_model(i):
...
@@ -334,15 +334,15 @@ class TokenizerManager:
                     ret["meta_info"]["decode_token_logprobs"], return_text_in_logprobs
                 )
             if top_logprobs_num > 0:
-                ret["meta_info"][
-                    "prefill_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
-                )
-                ret["meta_info"][
-                    "decode_top_logprobs"
-                ] = self.detokenize_top_logprobs_tokens(
-                    ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
-                )
+                ret["meta_info"]["prefill_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["prefill_top_logprobs"], return_text_in_logprobs
+                    )
+                )
+                ret["meta_info"]["decode_top_logprobs"] = (
+                    self.detokenize_top_logprobs_tokens(
+                        ret["meta_info"]["decode_top_logprobs"], return_text_in_logprobs
+                    )
+                )
 
         return ret
...
@@ -5,19 +5,23 @@ from typing import Iterable, List, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import Gemma2Config
 from vllm.config import CacheConfig, LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
+
+# FIXME: temporary solution, remove after next vllm release
+from vllm.model_executor.custom_op import CustomOp
 from vllm.model_executor.layers.activation import GeluAndMul
 # from vllm.model_executor.layers.layernorm import GemmaRMSNorm
-from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
-                                               QKVParallelLinear,
-                                               RowParallelLinear)
-from vllm.model_executor.layers.quantization.base_config import (
-    QuantizationConfig)
+from vllm.model_executor.layers.linear import (
+    MergedColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 # from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.layers.vocab_parallel_embedding import (
-    VocabParallelEmbedding)
+from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.sampling_metadata import SamplingMetadata
@@ -26,8 +30,6 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.managers.controller.model_runner import InputMetadata
 
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.custom_op import CustomOp
 
 class GemmaRMSNorm(CustomOp):
     """RMS normalization for Gemma.
@@ -76,13 +78,19 @@ class GemmaRMSNorm(CustomOp):
 # FIXME: temporary solution, remove after next vllm release
 from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
 
 
 class GemmaRotaryEmbedding(RotaryEmbedding):
     def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
         # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (base**(
-            torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float() /
-            self.rotary_dim))
+        inv_freq = 1.0 / (
+            base
+            ** (
+                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
+                / self.rotary_dim
+            )
+        )
         return inv_freq
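A quick numeric check of the inverse-frequency expression reformatted in this hunk; this is a standalone sketch, and the `rotary_dim` and `base` values are only illustrative:

import torch

rotary_dim, base = 8, 10000.0
inv_freq = 1.0 / (
    base ** (torch.arange(0, rotary_dim, 2, dtype=torch.int64).float() / rotary_dim)
)
print(inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010]) -- frequencies fall off geometrically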
@@ -98,18 +106,17 @@ class Gemma2MLP(nn.Module):
     ) -> None:
         super().__init__()
         self.gate_up_proj = MergedColumnParallelLinear(
-            hidden_size, [intermediate_size] * 2,
-            bias=False,
-            quant_config=quant_config)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           quant_config=quant_config)
+            hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size, hidden_size, bias=False, quant_config=quant_config
+        )
         if not (hidden_act == hidden_activation == "gelu_pytorch_tanh"):
             raise ValueError(
                 "Gemma2 uses `gelu_pytorch_tanh` as the hidden activation "
                 "function. Please set `hidden_act` and `hidden_activation` to "
-                "`gelu_pytorch_tanh`.")
+                "`gelu_pytorch_tanh`."
+            )
         self.act_fn = GeluAndMul(approximate="tanh")
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -121,17 +128,19 @@ class Gemma2MLP(nn.Module):
 
 class Gemma2Attention(nn.Module):
-    def __init__(self,
-                 layer_idx: int,
-                 config: Gemma2Config,
-                 hidden_size: int,
-                 num_heads: int,
-                 num_kv_heads: int,
-                 head_dim: int,
-                 max_position_embeddings: int,
-                 rope_theta: float,
-                 cache_config: Optional[CacheConfig] = None,
-                 quant_config: Optional[QuantizationConfig] = None) -> None:
+    def __init__(
+        self,
+        layer_idx: int,
+        config: Gemma2Config,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        head_dim: int,
+        max_position_embeddings: int,
+        rope_theta: float,
+        cache_config: Optional[CacheConfig] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+    ) -> None:
         super().__init__()
         self.layer_idx = layer_idx
         self.config = config
@@ -183,15 +192,16 @@ class Gemma2Attention(nn.Module):
         # from vLLM: FIXME(woosuk): While Gemma 2 uses sliding window attention for every
         # odd layer, vLLM currently ignores it and uses global attention for
         # all layers.
-        use_sliding_window = (layer_idx % 2 == 1
-                              and config.sliding_window is not None)
+        use_sliding_window = layer_idx % 2 == 1 and config.sliding_window is not None
         del use_sliding_window  # Unused.
-        self.attn = RadixAttention(self.num_heads,
-                                   self.head_dim,
-                                   self.scaling,
-                                   num_kv_heads=self.num_kv_heads,
-                                   layer_id=layer_idx,
-                                   logit_cap=self.config.attn_logit_softcapping)
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_idx,
+            logit_cap=self.config.attn_logit_softcapping,
+        )
 
     def forward(
         self,
@@ -238,14 +248,16 @@ class Gemma2DecoderLayer(nn.Module):
             hidden_activation=config.hidden_activation,
             quant_config=quant_config,
         )
-        self.input_layernorm = GemmaRMSNorm(config.hidden_size,
-                                            eps=config.rms_norm_eps)
-        self.post_attention_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                     eps=config.rms_norm_eps)
-        self.pre_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                      eps=config.rms_norm_eps)
-        self.post_feedforward_layernorm = GemmaRMSNorm(config.hidden_size,
-                                                       eps=config.rms_norm_eps)
+        self.input_layernorm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.pre_feedforward_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.post_feedforward_layernorm = GemmaRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
 
     def forward(
         self,
@@ -258,8 +270,7 @@ class Gemma2DecoderLayer(nn.Module):
             residual = hidden_states
             hidden_states = self.input_layernorm(hidden_states)
         else:
-            hidden_states, residual = self.input_layernorm(
-                hidden_states, residual)
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
         hidden_states = self.self_attn(
             positions=positions,
             hidden_states=hidden_states,
@@ -268,7 +279,8 @@ class Gemma2DecoderLayer(nn.Module):
         hidden_states = self.post_attention_layernorm(hidden_states)
 
         hidden_states, residual = self.pre_feedforward_layernorm(
-            hidden_states, residual)
+            hidden_states, residual
+        )
         hidden_states = self.mlp(hidden_states)
         hidden_states = self.post_feedforward_layernorm(hidden_states)
         return hidden_states, residual
@@ -289,10 +301,12 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = nn.ModuleList([
-            Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
-            for layer_idx in range(config.num_hidden_layers)
-        ])
+        self.layers = nn.ModuleList(
+            [
+                Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
         # Normalize the embedding by sqrt(hidden_size)
@@ -392,7 +406,7 @@ class Gemma2ForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
-            for (param_name, shard_name, shard_id) in stacked_params_mapping:
+            for param_name, shard_name, shard_id in stacked_params_mapping:
                 if shard_name not in name:
                     continue
                 name = name.replace(shard_name, param_name)
@@ -412,8 +426,7 @@ class Gemma2ForCausalLM(nn.Module):
                 if name.endswith(".bias") and name not in params_dict:
                     continue
                 param = params_dict[name]
-                weight_loader = getattr(param, "weight_loader",
-                                        default_weight_loader)
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
@@ -421,7 +434,8 @@ class Gemma2ForCausalLM(nn.Module):
         if unloaded_params:
             raise RuntimeError(
                 "Some weights are not initialized from checkpoints: "
-                f"{unloaded_params}")
+                f"{unloaded_params}"
+            )
 
 
-EntryClass = Gemma2ForCausalLM
\ No newline at end of file
+EntryClass = Gemma2ForCausalLM
@@ -5,14 +5,12 @@ import tqdm
 from torch import nn
 from transformers import LlamaConfig
 from vllm.config import CacheConfig
-from vllm.distributed import (
-    get_tensor_model_parallel_rank,
-)
+from vllm.distributed import get_tensor_model_parallel_rank
 from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.layers.logits_processor import LogitProcessorOutput
+from sglang.srt.managers.controller.model_runner import InputMetadata
 from sglang.srt.models.llama2 import LlamaModel
@@ -28,7 +26,9 @@ class LlamaForClassification(nn.Module):
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
-        self.classification_head = nn.Linear(config.hidden_size, config.classification_out_size)
+        self.classification_head = nn.Linear(
+            config.hidden_size, config.classification_out_size
+        )
         self.eos_token_id = config.eos_token_id
 
     def forward(
@@ -45,7 +45,9 @@ class LlamaForClassification(nn.Module):
         if scores.shape[0] != input_metadata.batch_size:
             print("Warning: the EOS tokens are missing in some sentences.")
-            scores = torch.ones((input_metadata.batch_size, self.config.classification_out_size)).to(input_ids.device)
+            scores = torch.ones(
+                (input_metadata.batch_size, self.config.classification_out_size)
+            ).to(input_ids.device)
 
         return LogitProcessorOutput(
             next_token_logits=scores,
@@ -101,4 +103,5 @@ class LlamaForClassification(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-EntryClass = LlamaForClassification
\ No newline at end of file
+
+EntryClass = LlamaForClassification
@@ -51,13 +51,12 @@ from sglang.srt.utils import (
     allocate_init_ports,
     assert_pkg_version,
     enable_show_time_cost,
-    send_addrs_to_rank_0,
     receive_addrs,
+    send_addrs_to_rank_0,
     start_rpyc_service_process,
 )
 from sglang.utils import get_exception_traceback
 
-
 logger = logging.getLogger(__name__)
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
@@ -152,9 +151,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     if server_args.disable_disk_cache:
         disable_cache()
     if not server_args.disable_flashinfer:
-        assert_pkg_version("flashinfer", "0.0.8", "Please uninstall the old version and "
-                           "reinstall the latest version by following the instructions "
-                           "at https://docs.flashinfer.ai/installation.html.")
+        assert_pkg_version(
+            "flashinfer",
+            "0.0.8",
+            "Please uninstall the old version and "
+            "reinstall the latest version by following the instructions "
+            "at https://docs.flashinfer.ai/installation.html.",
+        )
     if server_args.chat_template:
         # TODO: replace this with huggingface transformers template
         load_chat_template_for_openai_api(server_args.chat_template)
@@ -176,7 +179,9 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
             ModelPortArgs(
                 nccl_port=ports[3 + i * (tp_size_local + 1)],
                 model_tp_ips=[None] * tp_size_local,
-                model_tp_ports=ports[3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)],
+                model_tp_ports=ports[
+                    3 + i * (tp_size_local + 1) + 1 : 3 + (i + 1) * (tp_size_local + 1)
+                ],
             )
         )
     port_args = PortArgs(
@@ -194,9 +199,13 @@ def launch_server(server_args: ServerArgs, pipe_finish_writer, model_overide_arg
     else:
         receive_addrs(model_port_args[0], server_args)
         for i in range(tp_size_local):
-            start_rpyc_service_process(ModelTpService, model_port_args[0].model_tp_ports[i])
+            start_rpyc_service_process(
+                ModelTpService, model_port_args[0].model_tp_ports[i]
+            )
         if server_args.node_rank != 0:
-            logger.info(f"[node_rank={server_args.node_rank}]: Listen for connections...")
+            logger.info(
+                f"[node_rank={server_args.node_rank}]: Listen for connections..."
+            )
             while True:
                 pass
...
...@@ -137,17 +137,16 @@ class ServerArgs: ...@@ -137,17 +137,16 @@ class ServerArgs:
"--dtype", "--dtype",
type=str, type=str,
default=ServerArgs.dtype, default=ServerArgs.dtype,
choices=[ choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
"auto", "half", "float16", "bfloat16", "float", "float32" help="Data type for model weights and activations.\n\n"
],
help='Data type for model weights and activations.\n\n'
'* "auto" will use FP16 precision for FP32 and FP16 models, and ' '* "auto" will use FP16 precision for FP32 and FP16 models, and '
'BF16 precision for BF16 models.\n' "BF16 precision for BF16 models.\n"
'* "half" for FP16. Recommended for AWQ quantization.\n' '* "half" for FP16. Recommended for AWQ quantization.\n'
'* "float16" is the same as "half".\n' '* "float16" is the same as "half".\n'
'* "bfloat16" for a balance between precision and range.\n' '* "bfloat16" for a balance between precision and range.\n'
'* "float" is shorthand for FP32 precision.\n' '* "float" is shorthand for FP32 precision.\n'
'* "float32" for FP32 precision.') '* "float32" for FP32 precision.',
)
parser.add_argument( parser.add_argument(
"--trust-remote-code", "--trust-remote-code",
action="store_true", action="store_true",
...@@ -271,19 +270,12 @@ class ServerArgs: ...@@ -271,19 +270,12 @@ class ServerArgs:
parser.add_argument( parser.add_argument(
"--nccl-init-addr", "--nccl-init-addr",
type=str, type=str,
help="The nccl init address of multi-node server." help="The nccl init address of multi-node server.",
) )
parser.add_argument( parser.add_argument(
"--nnodes", "--nnodes", type=int, default=1, help="The number of nodes."
type=int,
default=1,
help="The number of nodes."
)
parser.add_argument(
"--node-rank",
type=int,
help="The node rank."
) )
parser.add_argument("--node-rank", type=int, help="The node rank.")
# Optimization/debug options # Optimization/debug options
parser.add_argument( parser.add_argument(
......
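The multi-node flags collapse to one-liners after formatting. A self-contained argparse sketch with the same flags; the defaults shown here are illustrative stand-ins, since the real defaults come from ServerArgs:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "half", "float16", "bfloat16", "float", "float32"],
        help="Data type for model weights and activations.",
    )
    parser.add_argument("--nccl-init-addr", type=str, help="The nccl init address of multi-node server.")
    parser.add_argument("--nnodes", type=int, default=1, help="The number of nodes.")
    parser.add_argument("--node-rank", type=int, help="The node rank.")

    args = parser.parse_args(["--dtype", "bfloat16", "--nnodes", "2", "--node-rank", "1"])
    print(args.dtype, args.nnodes, args.node_rank)  # -> bfloat16 2 1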
...@@ -432,13 +432,12 @@ def assert_pkg_version(pkg: str, min_version: str, message: str): ...@@ -432,13 +432,12 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):
if pkg_version.parse(installed_version) < pkg_version.parse(min_version): if pkg_version.parse(installed_version) < pkg_version.parse(min_version):
raise Exception( raise Exception(
f"{pkg} is installed with version {installed_version}, which " f"{pkg} is installed with version {installed_version}, which "
f"is less than the minimum required version {min_version}. " + f"is less than the minimum required version {min_version}. " + message
message
) )
except PackageNotFoundError: except PackageNotFoundError:
raise Exception( raise Exception(
f"{pkg} with minimum required version {min_version} is not installed. " + f"{pkg} with minimum required version {min_version} is not installed. "
message + message
) )
...@@ -474,24 +473,40 @@ def monkey_patch_vllm_dummy_weight_loader(): ...@@ -474,24 +473,40 @@ def monkey_patch_vllm_dummy_weight_loader():
""" """
from vllm.model_executor.model_loader.loader import ( from vllm.model_executor.model_loader.loader import (
ModelConfig, DeviceConfig, LoRAConfig, VisionLanguageConfig, CacheConfig,
ParallelConfig, SchedulerConfig, CacheConfig, nn, DeviceConfig,
set_default_torch_dtype, _initialize_model, initialize_dummy_weights, DummyModelLoader,
DummyModelLoader LoRAConfig,
ModelConfig,
ParallelConfig,
SchedulerConfig,
VisionLanguageConfig,
_initialize_model,
initialize_dummy_weights,
nn,
set_default_torch_dtype,
) )
def load_model(self, *, model_config: ModelConfig, def load_model(
device_config: DeviceConfig, self,
lora_config: Optional[LoRAConfig], *,
vision_language_config: Optional[VisionLanguageConfig], model_config: ModelConfig,
parallel_config: ParallelConfig, device_config: DeviceConfig,
scheduler_config: SchedulerConfig, lora_config: Optional[LoRAConfig],
cache_config: CacheConfig) -> nn.Module: vision_language_config: Optional[VisionLanguageConfig],
parallel_config: ParallelConfig,
scheduler_config: SchedulerConfig,
cache_config: CacheConfig,
) -> nn.Module:
with set_default_torch_dtype(model_config.dtype): with set_default_torch_dtype(model_config.dtype):
with torch.device(device_config.device): with torch.device(device_config.device):
model = _initialize_model(model_config, self.load_config, model = _initialize_model(
lora_config, vision_language_config, model_config,
cache_config) self.load_config,
lora_config,
vision_language_config,
cache_config,
)
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
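The import block and load_model signature above belong to a replacement method that, outside the visible hunk, gets assigned onto vllm's dummy weight loader. A generic sketch of that monkey-patching pattern, using made-up class names rather than vllm's real ones:

    class _Loader:
        def load_model(self) -> str:
            return "load real weights"

    def _load_model_without_weights(self) -> str:
        # Replacement that skips real weight loading; the patched code follows
        # the same pattern of assigning a new function onto the class attribute.
        return "initialize dummy weights"

    _Loader.load_model = _load_model_without_weights
    print(_Loader().load_model())  # -> "initialize dummy weights"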
...@@ -541,7 +556,7 @@ def get_ip_address(ifname): ...@@ -541,7 +556,7 @@ def get_ip_address(ifname):
ip_address = fcntl.ioctl( ip_address = fcntl.ioctl(
s.fileno(), s.fileno(),
0x8915, # SIOCGIFADDR 0x8915, # SIOCGIFADDR
struct.pack('256s', bytes(ifname[:15], 'utf-8')) struct.pack("256s", bytes(ifname[:15], "utf-8")),
)[20:24] )[20:24]
return socket.inet_ntoa(ip_address) return socket.inet_ntoa(ip_address)
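get_ip_address relies on a socket created above the visible hunk. A self-contained, Linux-only sketch of the same SIOCGIFADDR ioctl; the wrapper name here is made up:

    import fcntl
    import socket
    import struct

    def ipv4_of_interface(ifname: str) -> str:
        # Ask the kernel for the IPv4 address bound to `ifname`; the address
        # bytes sit at offset 20..24 of the returned ifreq structure.
        s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
        try:
            packed = fcntl.ioctl(
                s.fileno(),
                0x8915,  # SIOCGIFADDR
                struct.pack("256s", bytes(ifname[:15], "utf-8")),
            )[20:24]
        finally:
            s.close()
        return socket.inet_ntoa(packed)

    # e.g. ipv4_of_interface("eth0") -> "10.0.0.2" on a host with that interface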
...@@ -550,44 +565,66 @@ def send_addrs_to_rank_0(model_port_args, server_args): ...@@ -550,44 +565,66 @@ def send_addrs_to_rank_0(model_port_args, server_args):
assert server_args.node_rank != 0 and server_args.dp_size == 1 assert server_args.node_rank != 0 and server_args.dp_size == 1
import torch.distributed as dist import torch.distributed as dist
ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")) ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
)
ip_addr = get_ip_address(ifname) ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
ip_addr = [int(x) for x in ip_addr.split(".")] ip_addr = [int(x) for x in ip_addr.split(".")]
addrs_tensor = torch.tensor(ip_addr + model_port_args.model_tp_ports, dtype=torch.int) addrs_tensor = torch.tensor(
ip_addr + model_port_args.model_tp_ports, dtype=torch.int
)
init_method = f"tcp://{server_args.nccl_init_addr}" init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes) dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=server_args.node_rank,
world_size=server_args.nnodes,
)
dist.send(addrs_tensor, dst=0) dist.send(addrs_tensor, dst=0)
print(f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}") print(
f"Node {server_args.node_rank} sent: ip_address {ip_addr} and ports {model_port_args.model_tp_ports}"
)
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()
def receive_addrs(model_port_args, server_args): def receive_addrs(model_port_args, server_args):
assert server_args.node_rank == 0 and server_args.dp_size == 1 assert server_args.node_rank == 0 and server_args.dp_size == 1
import torch.distributed as dist import torch.distributed as dist
ifname = os.environ.get("SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")) ifname = os.environ.get(
"SGLANG_SOCKET_IFNAME", os.environ.get("NCCL_SOCKET_IFNAME", "eth0")
)
ip_addr = get_ip_address(ifname) ip_addr = get_ip_address(ifname)
num_tp_ports = server_args.tp_size // server_args.nnodes num_tp_ports = server_args.tp_size // server_args.nnodes
model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports model_port_args.model_tp_ips[:num_tp_ports] = [ip_addr] * num_tp_ports
init_method = f"tcp://{server_args.nccl_init_addr}" init_method = f"tcp://{server_args.nccl_init_addr}"
dist.init_process_group(backend="gloo", init_method=init_method, rank=server_args.node_rank, world_size=server_args.nnodes) dist.init_process_group(
backend="gloo",
init_method=init_method,
rank=server_args.node_rank,
world_size=server_args.nnodes,
)
for src_rank in range(1, server_args.nnodes): for src_rank in range(1, server_args.nnodes):
tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int) tensor = torch.zeros(4 + num_tp_ports, dtype=torch.int)
dist.recv(tensor, src=src_rank) dist.recv(tensor, src=src_rank)
ip = ".".join([str(x) for x in tensor[:4].tolist()]) ip = ".".join([str(x) for x in tensor[:4].tolist()])
ports = tensor[4:].tolist() ports = tensor[4:].tolist()
model_port_args.model_tp_ips[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = [ip] * num_tp_ports model_port_args.model_tp_ips[
model_port_args.model_tp_ports[num_tp_ports * src_rank: num_tp_ports * (src_rank + 1)] = ports num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
] = [ip] * num_tp_ports
model_port_args.model_tp_ports[
num_tp_ports * src_rank : num_tp_ports * (src_rank + 1)
] = ports
print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}") print(f"Node 0 received from rank {src_rank}: {tensor.tolist()}")
dist.barrier() dist.barrier()
dist.destroy_process_group() dist.destroy_process_group()
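The gloo exchange above serializes a dotted-quad IP plus the tensor-parallel ports into a single int tensor (4 octets followed by the ports). A hedged sketch of that encode/decode step in isolation, with illustrative helper names:

    import torch

    def encode_addrs(ip_addr: str, tp_ports):
        # 4 IP octets followed by the per-rank ports, matching the tensor
        # layout passed to dist.send above.
        octets = [int(x) for x in ip_addr.split(".")]
        return torch.tensor(octets + list(tp_ports), dtype=torch.int)

    def decode_addrs(t: torch.Tensor):
        # Inverse of encode_addrs, matching the slicing done after dist.recv.
        ip = ".".join(str(x) for x in t[:4].tolist())
        return ip, t[4:].tolist()

    print(decode_addrs(encode_addrs("10.0.0.2", [30011, 30012])))
    # -> ('10.0.0.2', [30011, 30012])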