Commit 37b63c24 authored by zhuwenwen
Browse files

[feat] add nn_moe

parent 2dc7ec2f
...@@ -25,6 +25,7 @@ class BenchmarkConfig(TypedDict): ...@@ -25,6 +25,7 @@ class BenchmarkConfig(TypedDict):
GROUP_SIZE_M: int GROUP_SIZE_M: int
num_warps: int num_warps: int
num_stages: int num_stages: int
num_ldmatrixes: Optional[int]
def benchmark_config( def benchmark_config(
...@@ -38,33 +39,60 @@ def benchmark_config( ...@@ -38,33 +39,60 @@ def benchmark_config(
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
num_iters: int = 100, num_iters: int = 100,
nn_moe: Optional[bool] = False
) -> float: ) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype) x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16: if use_int8_w8a16:
w1 = torch.randint(-127, if not nn_moe:
127, ( w1 = torch.randint(-127,
num_experts, 127, (
shard_intermediate_size, num_experts,
hidden_size, shard_intermediate_size,
), hidden_size,
dtype=torch.int8) ),
w2 = torch.randint(-127, dtype=torch.int8)
127, ( w2 = torch.randint(-127,
num_experts, 127, (
hidden_size, num_experts,
shard_intermediate_size // 2, hidden_size,
), shard_intermediate_size // 2,
dtype=torch.int8) ),
dtype=torch.int8)
else:
w1 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size // 2,
hidden_size
),
dtype=torch.int8)
else: else:
w1 = torch.randn(num_experts, if not nn_moe:
shard_intermediate_size, w1 = torch.randn(num_experts,
hidden_size, shard_intermediate_size,
dtype=init_dtype) hidden_size,
w2 = torch.randn(num_experts, dtype=init_dtype)
hidden_size, w2 = torch.randn(num_experts,
shard_intermediate_size // 2, hidden_size,
dtype=init_dtype) shard_intermediate_size // 2,
dtype=init_dtype)
else:
w1 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
shard_intermediate_size // 2,
hidden_size,
dtype=init_dtype)
gating_output = torch.randn(num_iters, gating_output = torch.randn(num_iters,
num_tokens, num_tokens,
num_experts, num_experts,
...@@ -109,6 +137,7 @@ def benchmark_config( ...@@ -109,6 +137,7 @@ def benchmark_config(
w2_scale=w2_scale, w2_scale=w2_scale,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
use_nn_moe=nn_moe,
) )
# JIT compilation & warmup # JIT compilation & warmup
...@@ -116,15 +145,16 @@ def benchmark_config( ...@@ -116,15 +145,16 @@ def benchmark_config(
torch.cuda.synchronize() torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph # Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph() # graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph): # with torch.cuda.graph(graph):
for _ in range(10): # for _ in range(10):
run() # run()
torch.cuda.synchronize() # torch.cuda.synchronize()
# Warmup # Warmup
for _ in range(5): for _ in range(5):
graph.replay() # graph.replay()
run()
torch.cuda.synchronize() torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True) start_event = torch.cuda.Event(enable_timing=True)
...@@ -136,16 +166,17 @@ def benchmark_config( ...@@ -136,16 +166,17 @@ def benchmark_config(
torch.cuda.synchronize() torch.cuda.synchronize()
start_event.record() start_event.record()
graph.replay() # graph.replay()
run()
end_event.record() end_event.record()
end_event.synchronize() end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event)) latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset() # graph.reset()
return avg return avg
def get_rocm_tuning_space(use_fp16): def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_mn_range = [16, 32, 64, 128, 256] block_mn_range = [16, 32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256] block_k_range = [16, 32, 64, 128, 256]
if not use_fp16: if not use_fp16:
...@@ -166,6 +197,9 @@ def get_rocm_tuning_space(use_fp16): ...@@ -166,6 +197,9 @@ def get_rocm_tuning_space(use_fp16):
"num_stages": num_stage_range, "num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range, "waves_per_eu": waves_per_eu_range,
} }
if nn_moe:
param_ranges["num_ldmatrixes"] = 1
if use_fp16: if use_fp16:
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
param_ranges["kpack"] = kpack_range param_ranges["kpack"] = kpack_range
...@@ -173,11 +207,11 @@ def get_rocm_tuning_space(use_fp16): ...@@ -173,11 +207,11 @@ def get_rocm_tuning_space(use_fp16):
return param_ranges return param_ranges
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]: def get_configs_compute_bound(use_fp16, nn_moe: Optional[bool] = False) -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = [] configs: List[BenchmarkConfig] = []
if current_platform.is_rocm(): if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16) param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
else: else:
# Reduced search space for faster tuning. # Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to # TODO(woosuk): Increase the search space and use a performance model to
...@@ -370,6 +404,7 @@ class BenchmarkWorker: ...@@ -370,6 +404,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
search_space: List[Dict[str, int]], search_space: List[Dict[str, int]],
nn_moe: Optional[bool] = False
) -> Dict[str, int]: ) -> Dict[str, int]:
best_config = None best_config = None
best_time = float("inf") best_time = float("inf")
...@@ -392,7 +427,8 @@ class BenchmarkWorker: ...@@ -392,7 +427,8 @@ class BenchmarkWorker:
dtype, dtype,
use_fp8_w8a8, use_fp8_w8a8,
use_int8_w8a16, use_int8_w8a16,
num_iters=20) num_iters=20,
nn_moe=nn_moe)
except triton.runtime.autotuner.OutOfResources: except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile. # Some configurations may be invalid and fail to compile.
continue continue
...@@ -407,35 +443,63 @@ class BenchmarkWorker: ...@@ -407,35 +443,63 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
    """Return a copy of ``config`` whose keys appear in a fixed order.

    The six core tuning keys always come first, in this order:
    ``BLOCK_SIZE_M``, ``BLOCK_SIZE_N``, ``BLOCK_SIZE_K``, ``GROUP_SIZE_M``,
    ``num_warps``, ``num_stages``.  They are followed by whichever of the
    optional keys ``num_ldmatrixes``, ``waves_per_eu``,
    ``matrix_instr_nonkdim`` and ``kpack`` are present in ``config``, in
    that order.  Keys not named here are dropped from the result.

    Args:
        config: Mapping holding at least the six core tuning keys.

    Returns:
        A new dict with the recognised keys in the order described above.

    Raises:
        KeyError: If one of the six core keys is missing from ``config``.
    """
    required_keys = (
        "BLOCK_SIZE_M",
        "BLOCK_SIZE_N",
        "BLOCK_SIZE_K",
        "GROUP_SIZE_M",
        "num_warps",
        "num_stages",
    )
    # "num_ldmatrixes" sits directly after "num_stages" and before the other
    # optional keys, matching the position it had in the nn_moe branch of the
    # previous two-branch implementation.
    optional_keys = (
        "num_ldmatrixes",
        "waves_per_eu",
        "matrix_instr_nonkdim",
        "kpack",
    )

    # Direct indexing (not .get) so a missing core key still raises KeyError.
    ordered = {key: config[key] for key in required_keys}
    ordered.update(
        (key, config[key]) for key in optional_keys if key in config)
    return ordered
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int, shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool, dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None: use_int8_w8a16: bool, use_nn_moe: Optional[bool] = False) -> None:
dtype_str = get_config_dtype_str(dtype, dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8) use_fp8_w8a8=use_fp8_w8a8)
...@@ -443,7 +507,7 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int, ...@@ -443,7 +507,7 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
# NOTE(woosuk): The current naming convention uses w2.shape[2], which # NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul. # is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2, filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str) dtype_str, use_nn_moe=use_nn_moe)
print(f"Writing best config to {filename}...") print(f"Writing best config to {filename}...")
with open(filename, "w") as f: with open(filename, "w") as f:
...@@ -466,7 +530,7 @@ def main(args: argparse.Namespace): ...@@ -466,7 +530,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "DeepseekV3ForCausalLM": elif config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM":
E = config.n_routed_experts E = config.n_routed_experts
topk = config.num_experts_per_tok topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size intermediate_size = config.moe_intermediate_size
...@@ -510,20 +574,20 @@ def main(args: argparse.Namespace): ...@@ -510,20 +574,20 @@ def main(args: argparse.Namespace):
if args.tune: if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16) is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16) search_space = get_configs_compute_bound(is_fp16, args.nn_moe)
print(f"Start tuning over {len(search_space)} configurations...") print(f"Start tuning over {len(search_space)} configurations...")
start = time.time() start = time.time()
configs = _distribute( configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size, "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space) topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, args.nn_moe)
for batch_size in batch_sizes]) for batch_size in batch_sizes])
best_configs = { best_configs = {
M: sort_config(config) M: sort_config(config)
for M, config in zip(batch_sizes, configs) for M, config in zip(batch_sizes, configs)
} }
save_configs(best_configs, E, shard_intermediate_size, hidden_size, save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16) topk, dtype, use_fp8_w8a8, use_int8_w8a16, use_nn_moe=args.nn_moe)
end = time.time() end = time.time()
print(f"Tuning took {end - start:.2f} seconds") print(f"Tuning took {end - start:.2f} seconds")
else: else:
...@@ -554,6 +618,7 @@ if __name__ == "__main__": ...@@ -554,6 +618,7 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0) parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False) parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true") parser.add_argument("--tune", action="store_true")
parser.add_argument("--nn_moe", type=bool, default=True)
parser.add_argument("--trust-remote-code", action="store_true") parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args() args = parser.parse_args()
......
...@@ -485,13 +485,9 @@ def get_version_add(sha: Optional[str] = None) -> str: ...@@ -485,13 +485,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha != 'Unknown': if sha != 'Unknown':
if sha is None: if sha is None:
sha = get_sha(vllm_root) sha = get_sha(vllm_root)
# if (major, minor) == ('2', '3'):
# version = 'das.opt1.' + sha[:7]
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1.' + sha[:7] version = 'das.opt1.' + sha[:7]
else: else:
# if (major, minor) == ('2', '3'):
# version = 'das.opt1'
if (major, minor) == ('2', '4'): if (major, minor) == ('2', '4'):
version = 'das.opt1' version = 'das.opt1'
......
...@@ -295,14 +295,27 @@ def fused_moe_kernel( ...@@ -295,14 +295,27 @@ def fused_moe_kernel(
# Map program ids `pid` to the block of C it should compute. # Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse. # This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) # num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) # num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n # num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group # group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M # first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M) # group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m) # pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m # pid_n = (pid % num_pid_in_group) // group_size_m
if GROUP_SIZE_M ==1:
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
else:
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ---------------------------------------------------------- # ----------------------------------------------------------
# Create pointers for the first blocks of A and B. # Create pointers for the first blocks of A and B.
...@@ -479,7 +492,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -479,7 +492,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_fp8_w8a8: bool, use_fp8_w8a8: bool,
use_int8_w8a16: bool, use_int8_w8a16: bool,
use_int4_w4a16: bool, use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights.stride(1) == 1 assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
...@@ -510,7 +524,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -510,7 +524,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = min(sorted_token_ids.shape[0], EM = min(sorted_token_ids.shape[0],
A.shape[0] * top_k * config['BLOCK_SIZE_M']) A.shape[0] * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv( grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), ) B.shape[1] if not use_nn_moe else B.shape[2], META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \ if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0: block_shape is not None and block_shape[1] > 0:
...@@ -566,15 +580,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -566,15 +580,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids, sorted_token_ids,
expert_ids, expert_ids,
num_tokens_post_padded, num_tokens_post_padded,
B.shape[1], B.shape[1] if not use_nn_moe else B.shape[2],
A.shape[1], A.shape[1] if not use_nn_moe else A.shape[2],
EM, EM,
topk_ids.numel(), topk_ids.numel(),
A.stride(0), A.stride(0),
A.stride(1), A.stride(1),
B.stride(0), B.stride(0),
B.stride(2), B.stride(2) if not use_nn_moe else B.stride(1),
B.stride(1), B.stride(1) if not use_nn_moe else B.stride(2),
C.stride(1), C.stride(1),
C.stride(2), C.stride(2),
A_scale.stride(0) A_scale.stride(0)
...@@ -602,12 +616,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -602,12 +616,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int,
                         N: int,
                         dtype: Optional[str],
                         block_shape: Optional[List[int]] = None,
                         use_nn_moe: Optional[bool] = False) -> str:
    """Build the JSON file name used for a tuned fused-MoE configuration.

    The name has the form
    ``E={E},N={N},device_name={device}[,dtype={dtype}][,block_shape={shape}][_nn].json``.

    Args:
        E: Value written into the ``E=`` field.
        N: Value written into the ``N=`` field.
        dtype: Optional dtype tag; the ``,dtype=`` selector is omitted when
            this is ``None`` or empty.
        block_shape: Optional block shape; the ``,block_shape=`` selector is
            omitted when this is ``None``, empty, or contains a falsy entry.
        use_nn_moe: When truthy, ``_nn`` is appended before ``.json`` so the
            name differs from the one produced without the flag.

    Returns:
        The configuration file name (no directory component).
    """
    # Spaces in the reported device name are replaced with underscores.
    device_name = current_platform.get_device_name().replace(" ", "_")
    dtype_selector = f",dtype={dtype}" if dtype else ""
    if block_shape and all(block_shape):
        block_shape_selector = f",block_shape={block_shape}"
    else:
        block_shape_selector = ""
    suffix = "_nn" if use_nn_moe else ""
    return (f"E={E},N={N},device_name={device_name}"
            f"{dtype_selector}{block_shape_selector}{suffix}.json")
# Adapted from: https://github.com/sgl-project/sglang/pull/2628 # Adapted from: https://github.com/sgl-project/sglang/pull/2628
...@@ -618,6 +636,7 @@ def get_moe_configs( ...@@ -618,6 +636,7 @@ def get_moe_configs(
dtype: Optional[str], dtype: Optional[str],
block_n: Optional[int] = None, block_n: Optional[int] = None,
block_k: Optional[int] = None, block_k: Optional[int] = None,
use_nn_moe: Optional[bool] = False
) -> Optional[Dict[int, Any]]: ) -> Optional[Dict[int, Any]]:
""" """
Return optimized configurations for the fused MoE kernel. Return optimized configurations for the fused MoE kernel.
...@@ -631,7 +650,7 @@ def get_moe_configs( ...@@ -631,7 +650,7 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs # First look up if an optimized configuration is available in the configs
# directory # directory
block_shape = [block_n, block_k] if block_n and block_k else None block_shape = [block_n, block_k] if block_n and block_k else None
json_file_name = get_config_file_name(E, N, dtype, block_shape) json_file_name = get_config_file_name(E, N, dtype, block_shape, use_nn_moe=use_nn_moe)
config_file_path = os.path.join( config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name) os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
...@@ -659,6 +678,7 @@ def get_default_config( ...@@ -659,6 +678,7 @@ def get_default_config(
dtype: Optional[str], dtype: Optional[str],
is_marlin: bool, is_marlin: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False
) -> Dict[str, int]: ) -> Dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None: if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0] # Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
...@@ -686,6 +706,8 @@ def get_default_config( ...@@ -686,6 +706,8 @@ def get_default_config(
"BLOCK_SIZE_K": 64, "BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1, "GROUP_SIZE_M": 1,
} }
if use_nn_moe:
config["num_ldmatrixes"] = 1
return config return config
...@@ -697,6 +719,7 @@ def try_get_optimal_moe_config( ...@@ -697,6 +719,7 @@ def try_get_optimal_moe_config(
M: int, M: int,
is_marlin: bool = False, is_marlin: bool = False,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False
): ):
from vllm.model_executor.layers.fused_moe import get_config from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config() override_config = get_config()
...@@ -704,10 +727,13 @@ def try_get_optimal_moe_config( ...@@ -704,10 +727,13 @@ def try_get_optimal_moe_config(
config = override_config config = override_config
else: else:
# First try to load optimal config from the file # First try to load optimal config from the file
E, _, N = w2_shape if not use_nn_moe:
E, _, N = w2_shape
else:
E, N, _ = w2_shape
block_n = block_shape[0] if block_shape else 0 block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0 block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k) configs = get_moe_configs(E, N, dtype, block_n, block_k, use_nn_moe=use_nn_moe)
if configs: if configs:
# If an optimal configuration map has been found, look up the # If an optimal configuration map has been found, look up the
...@@ -715,8 +741,8 @@ def try_get_optimal_moe_config( ...@@ -715,8 +741,8 @@ def try_get_optimal_moe_config(
config = configs[min(configs.keys(), key=lambda x: abs(x - M))] config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else: else:
# Else use the default config # Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype, config = get_default_config(M, E, N, w1_shape[2] if not use_nn_moe else w1_shape[1], top_k, dtype,
is_marlin, block_shape) is_marlin, block_shape, use_nn_moe=use_nn_moe)
return config return config
...@@ -843,10 +869,12 @@ def inplace_fused_experts(hidden_states: torch.Tensor, ...@@ -843,10 +869,12 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True, fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale, use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape) w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape,
use_nn_moe)
def inplace_fused_experts_fake( def inplace_fused_experts_fake(
...@@ -864,7 +892,8 @@ def inplace_fused_experts_fake( ...@@ -864,7 +892,8 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
pass pass
...@@ -891,11 +920,13 @@ def outplace_fused_experts( ...@@ -891,11 +920,13 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16, False, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape) a1_scale, a2_scale, block_shape,
use_nn_moe)
def outplace_fused_experts_fake( def outplace_fused_experts_fake(
...@@ -913,7 +944,8 @@ def outplace_fused_experts_fake( ...@@ -913,7 +944,8 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor: block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -940,20 +972,23 @@ def fused_experts(hidden_states: torch.Tensor, ...@@ -940,20 +972,23 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None): block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
if inplace: if inplace:
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2, torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids, topk_weights, topk_ids,
use_fp8_w8a8, use_int8_w8a16, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape) a2_scale, block_shape,
use_nn_moe)
return hidden_states return hidden_states
else: else:
return torch.ops.vllm.outplace_fused_experts( return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8, hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp, use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape) a1_scale, a2_scale, block_shape,
use_nn_moe)
def fused_experts_impl(hidden_states: torch.Tensor, def fused_experts_impl(hidden_states: torch.Tensor,
...@@ -971,11 +1006,14 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -971,11 +1006,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None, w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None): block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
# Check constraints. # Check constraints.
if use_int4_w4a16: if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[ assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch" 2], "Hidden size mismatch"
elif use_nn_moe:
assert hidden_states.shape[1] == w1.shape[1], "Hidden size mismatch"
else: else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch" assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
...@@ -988,7 +1026,10 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -988,7 +1026,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
] ]
num_tokens, _ = hidden_states.shape num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape if use_nn_moe:
E, _, N = w1.shape
else:
E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue: # We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938 # https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
...@@ -1005,6 +1046,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1005,6 +1046,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids.shape[1], topk_ids.shape[1],
config_dtype, config_dtype,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe,
) )
config = get_config_func(M) config = get_config_func(M)
...@@ -1015,7 +1057,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1015,7 +1057,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2), intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]), intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1] if not use_nn_moe else w2.shape[2]),
device=hidden_states.device, device=hidden_states.device,
dtype=hidden_states.dtype) dtype=hidden_states.dtype)
...@@ -1077,7 +1119,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1077,7 +1119,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape) block_shape=block_shape,
use_nn_moe=use_nn_moe)
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N)) intermediate_cache1.view(-1, N))
...@@ -1100,7 +1143,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1100,7 +1143,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape) block_shape=block_shape,
use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])
...@@ -1129,6 +1173,7 @@ def fused_moe( ...@@ -1129,6 +1173,7 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None, a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
This function computes a Mixture of Experts (MoE) layer using two sets of This function computes a Mixture of Experts (MoE) layer using two sets of
...@@ -1200,4 +1245,5 @@ def fused_moe( ...@@ -1200,4 +1245,5 @@ def fused_moe(
w2_zp=w2_zp, w2_zp=w2_zp,
a1_scale=a1_scale, a1_scale=a1_scale,
a2_scale=a2_scale, a2_scale=a2_scale,
block_shape=block_shape) block_shape=block_shape,
use_nn_moe=use_nn_moe)
import os
from abc import abstractmethod from abc import abstractmethod
from enum import Enum from enum import Enum
from typing import Callable, List, Optional, Tuple from typing import Callable, List, Optional, Tuple
...@@ -66,24 +67,41 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -66,24 +67,41 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def create_weights(self, layer: torch.nn.Module, num_experts: int,
                   hidden_size: int, intermediate_size_per_partition: int,
                   params_dtype: torch.dtype, use_nn_moe: bool = False,
                   **extra_weight_attrs):
    """Allocate and register the unquantized expert weights on ``layer``.

    Two parameters are registered, both with ``requires_grad=False``:

    * ``w13_weight`` -- the fused gate/up projection.
    * ``w2_weight``  -- the down projection.

    Args:
        layer: Module that receives the two parameters.
        num_experts: Size of the leading (expert) dimension.
        hidden_size: Hidden dimension of the model.
        intermediate_size_per_partition: Intermediate dimension held by
            this partition; the fused gate/up weight uses twice this size.
        params_dtype: dtype of the allocated tensors.
        use_nn_moe: When true, the last two dimensions of both weights are
            swapped relative to the default layout. Defaults to ``False``
            so callers that do not know about the flag keep the original
            layout.
        **extra_weight_attrs: Forwarded to ``set_weight_attrs`` for both
            parameters.
    """
    gate_up_size = 2 * intermediate_size_per_partition
    if use_nn_moe:
        w13_shape = (num_experts, hidden_size, gate_up_size)
        w2_shape = (num_experts, intermediate_size_per_partition,
                    hidden_size)
    else:
        w13_shape = (num_experts, gate_up_size, hidden_size)
        w2_shape = (num_experts, hidden_size,
                    intermediate_size_per_partition)

    # Fused gate_up_proj (column parallel)
    w13_weight = torch.nn.Parameter(torch.empty(*w13_shape,
                                                dtype=params_dtype),
                                    requires_grad=False)
    layer.register_parameter("w13_weight", w13_weight)
    set_weight_attrs(w13_weight, extra_weight_attrs)

    # down_proj (row parallel)
    w2_weight = torch.nn.Parameter(torch.empty(*w2_shape,
                                               dtype=params_dtype),
                                   requires_grad=False)
    layer.register_parameter("w2_weight", w2_weight)
    set_weight_attrs(w2_weight, extra_weight_attrs)
...@@ -113,7 +131,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -113,7 +131,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False
) -> torch.Tensor: ) -> torch.Tensor:
return self.forward(x=x, return self.forward(x=x,
layer=layer, layer=layer,
...@@ -125,7 +144,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -125,7 +144,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group, num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function, custom_routing_function=custom_routing_function,
scoring_func=scoring_func, scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias) e_score_correction_bias=e_score_correction_bias,
use_nn_moe=use_nn_moe)
def forward_cuda( def forward_cuda(
self, self,
...@@ -139,7 +159,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -139,7 +159,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax", scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False
) -> torch.Tensor: ) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts( topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x, hidden_states=x,
...@@ -158,7 +179,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -158,7 +179,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight, w2=layer.w2_weight,
topk_weights=topk_weights, topk_weights=topk_weights,
topk_ids=topk_ids, topk_ids=topk_ids,
inplace=True) inplace=True,
use_nn_moe=use_nn_moe)
def forward_cpu( def forward_cpu(
self, self,
...@@ -171,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -171,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_group: Optional[int] = None, topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None, num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None, custom_routing_function: Optional[Callable] = None,
use_nn_moe: Optional[bool] = False,
**kwargs, **kwargs,
): ):
assert custom_routing_function is None assert custom_routing_function is None
...@@ -298,12 +321,19 @@ class FusedMoE(torch.nn.Module): ...@@ -298,12 +321,19 @@ class FusedMoE(torch.nn.Module):
self.intermediate_size_per_partition, self.intermediate_size_per_partition,
"params_dtype": params_dtype, "params_dtype": params_dtype,
"weight_loader": self.weight_loader, "weight_loader": self.weight_loader,
"use_nn_moe":self.use_nn_moe,
} }
# need full intermediate size pre-sharding for WNA16 act order # need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ == if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"): "CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size moe_quant_params["intermediate_size_full"] = intermediate_size
if quant_config is None:
# Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
else:
self.use_nn_moe = False
self.quant_method.create_weights(layer=self, **moe_quant_params) self.quant_method.create_weights(layer=self, **moe_quant_params)
def _load_per_tensor_weight_scale(self, shard_id: str, def _load_per_tensor_weight_scale(self, shard_id: str,
...@@ -372,7 +402,8 @@ class FusedMoE(torch.nn.Module): ...@@ -372,7 +402,8 @@ class FusedMoE(torch.nn.Module):
# Index the loaded weight for tp sharding. # Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim # gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2 shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank, loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim,
shard_size * tp_rank,
shard_size) shard_size)
# Narrow parameter and load. # Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13. # w1, gate_proj: Load into first logical weight of w13.
...@@ -382,7 +413,10 @@ class FusedMoE(torch.nn.Module): ...@@ -382,7 +413,10 @@ class FusedMoE(torch.nn.Module):
else: else:
assert shard_id == "w3" assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size) expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight) if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_w2(self, def _load_w2(self,
expert_data: torch.Tensor, expert_data: torch.Tensor,
...@@ -396,18 +430,24 @@ class FusedMoE(torch.nn.Module): ...@@ -396,18 +430,24 @@ class FusedMoE(torch.nn.Module):
# Narrow parameter and load. # Narrow parameter and load.
shard_size = expert_data.shape[shard_dim] shard_size = expert_data.shape[shard_dim]
if not load_full: if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim, loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim,
shard_size * tp_rank, shard_size * tp_rank,
shard_size) shard_size)
# w2, down_proj: Load into only logical weight of w2. # w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight) if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_single_value(self, param: torch.nn.Parameter,
                       loaded_weight: torch.Tensor, expert_id: int):
    """Write ``loaded_weight`` into slot ``expert_id`` of ``param``.

    Args:
        param: Parameter whose ``.data`` is indexed by expert along dim 0.
        loaded_weight: Value to store for the expert.
        expert_id: Index of the expert slot to overwrite.

    When ``self.use_nn_moe`` is set, the dimensions of ``loaded_weight``
    are reversed before the assignment (a plain transpose for 2-D input).
    Tensors with fewer than two dimensions are stored unchanged.
    """
    param_data = param.data

    value = loaded_weight
    if self.use_nn_moe and value.dim() >= 2:
        # Equivalent to ``value.T`` but spelled out with permute:
        # ``Tensor.T`` is deprecated for tensors whose rank is not 2,
        # and this loader also sees scalar / 1-D values.
        value = value.permute(*reversed(range(value.dim())))

    # Input scales can be loaded directly and should be equal.
    param_data[expert_id] = value
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor, def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int): shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
...@@ -419,7 +459,10 @@ class FusedMoE(torch.nn.Module): ...@@ -419,7 +459,10 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank) tp_rank=tp_rank)
else: else:
assert shard_id in ("w1", "w3") assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight) if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def weight_loader(self, param: torch.nn.Parameter, def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str, loaded_weight: torch.Tensor, weight_name: str,
...@@ -450,7 +493,7 @@ class FusedMoE(torch.nn.Module): ...@@ -450,7 +493,7 @@ class FusedMoE(torch.nn.Module):
# is_transposed: if the dim to shard the weight # is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors # should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is # should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False) is_transposed = getattr(param, "is_transposed", False) or self.use_nn_moe
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id] shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed: if is_transposed:
shard_dim = int(not shard_dim) shard_dim = int(not shard_dim)
...@@ -592,7 +635,8 @@ class FusedMoE(torch.nn.Module): ...@@ -592,7 +635,8 @@ class FusedMoE(torch.nn.Module):
num_expert_group=self.num_expert_group, num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function, custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func, scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias) e_score_correction_bias=self.e_score_correction_bias,
use_nn_moe=self.use_nn_moe)
if self.reduce_results and self.tp_size > 1: if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce( final_hidden_states = tensor_model_parallel_all_reduce(
......
...@@ -28,7 +28,8 @@ def get_model_architecture( ...@@ -28,7 +28,8 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", []) architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", []) visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# 'Qwen2VLForConditionalGeneration' # 'Qwen2VLForConditionalGeneration'
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM'] support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV3ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures): if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0': if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []: if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
"""Inference-only DeepseekV3 model.""" """Inference-only DeepseekV3 model."""
import os
import re
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch import torch
...@@ -52,6 +54,7 @@ from .interfaces import SupportsPP ...@@ -52,6 +54,7 @@ from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter, from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers, make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix) maybe_prefix)
from vllm import _custom_ops as ops
class DeepseekV3MLP(nn.Module): class DeepseekV3MLP(nn.Module):
...@@ -666,6 +669,15 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -666,6 +669,15 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
self.sampler = get_sampler() self.sampler = get_sampler()
self.make_empty_intermediate_tensors = ( self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors) self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_mla = False
if hasattr(vllm_config.model_config, "use_mla"):
self.use_mla = vllm_config.model_config.use_mla
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids) return self.model.get_input_embeddings(input_ids)
...@@ -800,4 +812,42 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP): ...@@ -800,4 +812,42 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
default_weight_loader) default_weight_loader)
weight_loader(param, loaded_weight) weight_loader(param, loaded_weight)
loaded_params.add(name) loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj"
]
if not self.use_mla:
lay_key_words.extend([
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
])
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params return loaded_params
...@@ -521,6 +521,13 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): ...@@ -521,6 +521,13 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
combined_words = "|".join(lay_key_words) combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items(): for layername, weight in params_dict.items():
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment