Commit 37b63c24 authored by zhuwenwen's avatar zhuwenwen
Browse files

[feat] add nn_moe

parent 2dc7ec2f
......@@ -25,6 +25,7 @@ class BenchmarkConfig(TypedDict):
GROUP_SIZE_M: int
num_warps: int
num_stages: int
num_ldmatrixes: Optional[int]
def benchmark_config(
......@@ -38,33 +39,60 @@ def benchmark_config(
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
num_iters: int = 100,
nn_moe: Optional[bool] = False
) -> float:
init_dtype = torch.float16 if use_fp8_w8a8 else dtype
x = torch.randn(num_tokens, hidden_size, dtype=dtype)
if use_int8_w8a16:
w1 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
if not nn_moe:
w1 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size,
hidden_size,
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size // 2,
),
dtype=torch.int8)
else:
w1 = torch.randint(-127,
127, (
num_experts,
hidden_size,
shard_intermediate_size
),
dtype=torch.int8)
w2 = torch.randint(-127,
127, (
num_experts,
shard_intermediate_size // 2,
hidden_size
),
dtype=torch.int8)
else:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
if not nn_moe:
w1 = torch.randn(num_experts,
shard_intermediate_size,
hidden_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size // 2,
dtype=init_dtype)
else:
w1 = torch.randn(num_experts,
hidden_size,
shard_intermediate_size,
dtype=init_dtype)
w2 = torch.randn(num_experts,
shard_intermediate_size // 2,
hidden_size,
dtype=init_dtype)
gating_output = torch.randn(num_iters,
num_tokens,
num_experts,
......@@ -109,6 +137,7 @@ def benchmark_config(
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
use_nn_moe=nn_moe,
)
# JIT compilation & warmup
......@@ -116,15 +145,16 @@ def benchmark_config(
torch.cuda.synchronize()
# Capture 10 invocations with CUDA graph
graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
for _ in range(10):
run()
torch.cuda.synchronize()
# graph = torch.cuda.CUDAGraph()
# with torch.cuda.graph(graph):
# for _ in range(10):
# run()
# torch.cuda.synchronize()
# Warmup
for _ in range(5):
graph.replay()
# graph.replay()
run()
torch.cuda.synchronize()
start_event = torch.cuda.Event(enable_timing=True)
......@@ -136,16 +166,17 @@ def benchmark_config(
torch.cuda.synchronize()
start_event.record()
graph.replay()
# graph.replay()
run()
end_event.record()
end_event.synchronize()
latencies.append(start_event.elapsed_time(end_event))
avg = sum(latencies) / (num_iters * 10) * 1000 # us
graph.reset()
# graph.reset()
return avg
def get_rocm_tuning_space(use_fp16):
def get_rocm_tuning_space(use_fp16, nn_moe: Optional[bool] = False):
block_mn_range = [16, 32, 64, 128, 256]
block_k_range = [16, 32, 64, 128, 256]
if not use_fp16:
......@@ -166,6 +197,9 @@ def get_rocm_tuning_space(use_fp16):
"num_stages": num_stage_range,
"waves_per_eu": waves_per_eu_range,
}
if nn_moe:
param_ranges["num_ldmatrixes"] = 1
if use_fp16:
param_ranges["matrix_instr_nonkdim"] = matrix_instr_nonkdim_range
param_ranges["kpack"] = kpack_range
......@@ -173,11 +207,11 @@ def get_rocm_tuning_space(use_fp16):
return param_ranges
def get_configs_compute_bound(use_fp16) -> List[Dict[str, int]]:
def get_configs_compute_bound(use_fp16, nn_moe: Optional[bool] = False) -> List[Dict[str, int]]:
configs: List[BenchmarkConfig] = []
if current_platform.is_rocm():
param_ranges = get_rocm_tuning_space(use_fp16)
param_ranges = get_rocm_tuning_space(use_fp16, nn_moe)
else:
# Reduced search space for faster tuning.
# TODO(woosuk): Increase the search space and use a performance model to
......@@ -370,6 +404,7 @@ class BenchmarkWorker:
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
search_space: List[Dict[str, int]],
nn_moe: Optional[bool] = False
) -> Dict[str, int]:
best_config = None
best_time = float("inf")
......@@ -392,7 +427,8 @@ class BenchmarkWorker:
dtype,
use_fp8_w8a8,
use_int8_w8a16,
num_iters=20)
num_iters=20,
nn_moe=nn_moe)
except triton.runtime.autotuner.OutOfResources:
# Some configurations may be invalid and fail to compile.
continue
......@@ -407,35 +443,63 @@ class BenchmarkWorker:
def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
return {
"BLOCK_SIZE_M":
config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N":
config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K":
config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
**({
if "num_ldmatrixes" not in config:
return {
"BLOCK_SIZE_M":
config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N":
config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K":
config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
**({
"waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
}
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
}
else:
return {
"BLOCK_SIZE_M":
config["BLOCK_SIZE_M"],
"BLOCK_SIZE_N":
config["BLOCK_SIZE_N"],
"BLOCK_SIZE_K":
config["BLOCK_SIZE_K"],
"GROUP_SIZE_M":
config["GROUP_SIZE_M"],
"num_warps":
config["num_warps"],
"num_stages":
config["num_stages"],
"num_ldmatrixes":
config["num_ldmatrixes"],
**({
"waves_per_eu": config["waves_per_eu"]
} if "waves_per_eu" in config else {}),
**({
"matrix_instr_nonkdim": config["matrix_instr_nonkdim"]
} if "matrix_instr_nonkdim" in config else {}),
**({
"kpack": config["kpack"]
} if "kpack" in config else {}),
}
def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
shard_intermediate_size: int, hidden_size: int, topk: int,
dtype: torch.dtype, use_fp8_w8a8: bool,
use_int8_w8a16: bool) -> None:
use_int8_w8a16: bool, use_nn_moe: Optional[bool] = False) -> None:
dtype_str = get_config_dtype_str(dtype,
use_int8_w8a16=use_int8_w8a16,
use_fp8_w8a8=use_fp8_w8a8)
......@@ -443,7 +507,7 @@ def save_configs(configs: Dict[int, BenchmarkConfig], num_experts: int,
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
# is the intermediate size after silu_and_mul.
filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
dtype_str)
dtype_str, use_nn_moe=use_nn_moe)
print(f"Writing best config to {filename}...")
with open(filename, "w") as f:
......@@ -466,7 +530,7 @@ def main(args: argparse.Namespace):
topk = config.num_experts_per_tok
intermediate_size = config.intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
elif config.architectures[0] == "DeepseekV3ForCausalLM":
elif config.architectures[0] == "DeepseekV2ForCausalLM" or "DeepseekV3ForCausalLM":
E = config.n_routed_experts
topk = config.num_experts_per_tok
intermediate_size = config.moe_intermediate_size
......@@ -510,20 +574,20 @@ def main(args: argparse.Namespace):
if args.tune:
is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
search_space = get_configs_compute_bound(is_fp16)
search_space = get_configs_compute_bound(is_fp16, args.nn_moe)
print(f"Start tuning over {len(search_space)} configurations...")
start = time.time()
configs = _distribute(
"tune", [(batch_size, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space, args.nn_moe)
for batch_size in batch_sizes])
best_configs = {
M: sort_config(config)
for M, config in zip(batch_sizes, configs)
}
save_configs(best_configs, E, shard_intermediate_size, hidden_size,
topk, dtype, use_fp8_w8a8, use_int8_w8a16)
topk, dtype, use_fp8_w8a8, use_int8_w8a16, use_nn_moe=args.nn_moe)
end = time.time()
print(f"Tuning took {end - start:.2f} seconds")
else:
......@@ -554,6 +618,7 @@ if __name__ == "__main__":
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--batch-size", type=int, required=False)
parser.add_argument("--tune", action="store_true")
parser.add_argument("--nn_moe", type=bool, default=True)
parser.add_argument("--trust-remote-code", action="store_true")
args = parser.parse_args()
......
......@@ -485,13 +485,9 @@ def get_version_add(sha: Optional[str] = None) -> str:
if sha != 'Unknown':
if sha is None:
sha = get_sha(vllm_root)
# if (major, minor) == ('2', '3'):
# version = 'das.opt1.' + sha[:7]
if (major, minor) == ('2', '4'):
version = 'das.opt1.' + sha[:7]
else:
# if (major, minor) == ('2', '3'):
# version = 'das.opt1'
if (major, minor) == ('2', '4'):
version = 'das.opt1'
......
......@@ -295,14 +295,27 @@ def fused_moe_kernel(
# Map program ids `pid` to the block of C it should compute.
# This is done in a grouped ordering to promote L2 data reuse.
pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
# num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
# num_pid_in_group = GROUP_SIZE_M * num_pid_n
# group_id = pid // num_pid_in_group
# first_pid_m = group_id * GROUP_SIZE_M
# group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
# pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
# pid_n = (pid % num_pid_in_group) // group_size_m
if GROUP_SIZE_M ==1:
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
pid_m = pid // num_pid_n
pid_n = pid % num_pid_n
else:
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
num_pid_in_group = GROUP_SIZE_M * num_pid_n
group_id = pid // num_pid_in_group
first_pid_m = group_id * GROUP_SIZE_M
group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
pid_n = (pid % num_pid_in_group) // group_size_m
# ----------------------------------------------------------
# Create pointers for the first blocks of A and B.
......@@ -479,7 +492,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_fp8_w8a8: bool,
use_int8_w8a16: bool,
use_int4_w4a16: bool,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None:
assert topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1
......@@ -510,7 +524,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
EM = min(sorted_token_ids.shape[0],
A.shape[0] * top_k * config['BLOCK_SIZE_M'])
grid = lambda META: (triton.cdiv(EM, META['BLOCK_SIZE_M']) * triton.cdiv(
B.shape[1], META['BLOCK_SIZE_N']), )
B.shape[1] if not use_nn_moe else B.shape[2], META['BLOCK_SIZE_N']), )
if (use_int8_w8a16 or use_int4_w4a16) and \
block_shape is not None and block_shape[1] > 0:
......@@ -566,15 +580,15 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
B.shape[1],
A.shape[1],
B.shape[1] if not use_nn_moe else B.shape[2],
A.shape[1] if not use_nn_moe else A.shape[2],
EM,
topk_ids.numel(),
A.stride(0),
A.stride(1),
B.stride(0),
B.stride(2),
B.stride(1),
B.stride(2) if not use_nn_moe else B.stride(1),
B.stride(1) if not use_nn_moe else B.stride(2),
C.stride(1),
C.stride(2),
A_scale.stride(0)
......@@ -602,12 +616,16 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
def get_config_file_name(E: int,
N: int,
dtype: Optional[str],
block_shape: Optional[List[int]] = None) -> str:
block_shape: Optional[List[int]] = None, use_nn_moe: Optional[bool] = False) -> str:
device_name = current_platform.get_device_name().replace(" ", "_")
dtype_selector = "" if not dtype else f",dtype={dtype}"
block_shape_selector = ("" if not block_shape or not all(block_shape) else
f",block_shape={block_shape}")
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
f",block_shape={block_shape}")
if not use_nn_moe:
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}.json" # noqa: E501
else:
return f"E={E},N={N},device_name={device_name}{dtype_selector}{block_shape_selector}_nn.json"
# Adapted from: https://github.com/sgl-project/sglang/pull/2628
......@@ -618,6 +636,7 @@ def get_moe_configs(
dtype: Optional[str],
block_n: Optional[int] = None,
block_k: Optional[int] = None,
use_nn_moe: Optional[bool] = False
) -> Optional[Dict[int, Any]]:
"""
Return optimized configurations for the fused MoE kernel.
......@@ -631,7 +650,7 @@ def get_moe_configs(
# First look up if an optimized configuration is available in the configs
# directory
block_shape = [block_n, block_k] if block_n and block_k else None
json_file_name = get_config_file_name(E, N, dtype, block_shape)
json_file_name = get_config_file_name(E, N, dtype, block_shape, use_nn_moe=use_nn_moe)
config_file_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
......@@ -659,6 +678,7 @@ def get_default_config(
dtype: Optional[str],
is_marlin: bool,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False
) -> Dict[str, int]:
if dtype == "fp8_w8a8" and block_shape is not None:
# Block-wise quant: BLOCK_SIZE_N must be divisible by block_shape[0]
......@@ -686,6 +706,8 @@ def get_default_config(
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
}
if use_nn_moe:
config["num_ldmatrixes"] = 1
return config
......@@ -697,6 +719,7 @@ def try_get_optimal_moe_config(
M: int,
is_marlin: bool = False,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False
):
from vllm.model_executor.layers.fused_moe import get_config
override_config = get_config()
......@@ -704,10 +727,13 @@ def try_get_optimal_moe_config(
config = override_config
else:
# First try to load optimal config from the file
E, _, N = w2_shape
if not use_nn_moe:
E, _, N = w2_shape
else:
E, N, _ = w2_shape
block_n = block_shape[0] if block_shape else 0
block_k = block_shape[1] if block_shape else 0
configs = get_moe_configs(E, N, dtype, block_n, block_k)
configs = get_moe_configs(E, N, dtype, block_n, block_k, use_nn_moe=use_nn_moe)
if configs:
# If an optimal configuration map has been found, look up the
......@@ -715,8 +741,8 @@ def try_get_optimal_moe_config(
config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
else:
# Else use the default config
config = get_default_config(M, E, N, w1_shape[2], top_k, dtype,
is_marlin, block_shape)
config = get_default_config(M, E, N, w1_shape[2] if not use_nn_moe else w1_shape[1], top_k, dtype,
is_marlin, block_shape, use_nn_moe=use_nn_moe)
return config
......@@ -843,10 +869,12 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape)
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale, block_shape,
use_nn_moe)
def inplace_fused_experts_fake(
......@@ -864,7 +892,8 @@ def inplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> None:
pass
......@@ -891,11 +920,13 @@ def outplace_fused_experts(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape)
a1_scale, a2_scale, block_shape,
use_nn_moe)
def outplace_fused_experts_fake(
......@@ -913,7 +944,8 @@ def outplace_fused_experts_fake(
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False) -> torch.Tensor:
return torch.empty_like(hidden_states)
......@@ -940,20 +972,23 @@ def fused_experts(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
if inplace:
torch.ops.vllm.inplace_fused_experts(hidden_states, w1, w2,
topk_weights, topk_ids,
use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
a2_scale, block_shape,
use_nn_moe)
return hidden_states
else:
return torch.ops.vllm.outplace_fused_experts(
hidden_states, w1, w2, topk_weights, topk_ids, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, w1_scale, w2_scale, w1_zp, w2_zp,
a1_scale, a2_scale, block_shape)
a1_scale, a2_scale, block_shape,
use_nn_moe)
def fused_experts_impl(hidden_states: torch.Tensor,
......@@ -971,11 +1006,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
w2_zp: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None):
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False):
# Check constraints.
if use_int4_w4a16:
assert hidden_states.shape[1] // 2 == w1.shape[
2], "Hidden size mismatch"
elif use_nn_moe:
assert hidden_states.shape[1] == w1.shape[1], "Hidden size mismatch"
else:
assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
......@@ -988,7 +1026,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
]
num_tokens, _ = hidden_states.shape
E, N, _ = w1.shape
if use_nn_moe:
E, _, N = w1.shape
else:
E, N, _ = w1.shape
# We execute the fused_moe kernel in chunks to circumvent this issue:
# https://github.com/vllm-project/vllm/issues/5938
CHUNK_SIZE = envs.VLLM_FUSED_MOE_CHUNK_SIZE
......@@ -1005,6 +1046,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids.shape[1],
config_dtype,
block_shape=block_shape,
use_nn_moe=use_nn_moe,
)
config = get_config_func(M)
......@@ -1015,7 +1057,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
device=hidden_states.device,
dtype=hidden_states.dtype)
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1] if not use_nn_moe else w2.shape[2]),
device=hidden_states.device,
dtype=hidden_states.dtype)
......@@ -1077,7 +1119,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape)
block_shape=block_shape,
use_nn_moe=use_nn_moe)
torch.ops._C.silu_and_mul(intermediate_cache2,
intermediate_cache1.view(-1, N))
......@@ -1100,7 +1143,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
block_shape=block_shape)
block_shape=block_shape,
use_nn_moe=use_nn_moe)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx])
......@@ -1129,6 +1173,7 @@ def fused_moe(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool] = False,
) -> torch.Tensor:
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
......@@ -1200,4 +1245,5 @@ def fused_moe(
w2_zp=w2_zp,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape)
block_shape=block_shape,
use_nn_moe=use_nn_moe)
import os
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
......@@ -66,24 +67,41 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
params_dtype: torch.dtype, use_nn_moe: bool,
**extra_weight_attrs):
# Fused gate_up_proj (column parallel)
w13_weight = torch.nn.Parameter(torch.empty(
if not use_nn_moe:
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
2 * intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
else:
w13_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
2 * intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
# down_proj (row parallel)
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
if not use_nn_moe:
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
hidden_size,
intermediate_size_per_partition,
dtype=params_dtype),
requires_grad=False)
else:
w2_weight = torch.nn.Parameter(torch.empty(
num_experts,
intermediate_size_per_partition,
hidden_size,
dtype=params_dtype),
requires_grad=False)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
......@@ -113,7 +131,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False
) -> torch.Tensor:
return self.forward(x=x,
layer=layer,
......@@ -125,7 +144,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group=num_expert_group,
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
use_nn_moe=use_nn_moe)
def forward_cuda(
self,
......@@ -139,7 +159,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......@@ -158,7 +179,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True)
inplace=True,
use_nn_moe=use_nn_moe)
def forward_cpu(
self,
......@@ -171,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
topk_group: Optional[int] = None,
num_expert_group: Optional[int] = None,
custom_routing_function: Optional[Callable] = None,
use_nn_moe: Optional[bool] = False,
**kwargs,
):
assert custom_routing_function is None
......@@ -298,12 +321,19 @@ class FusedMoE(torch.nn.Module):
self.intermediate_size_per_partition,
"params_dtype": params_dtype,
"weight_loader": self.weight_loader,
"use_nn_moe":self.use_nn_moe,
}
# need full intermediate size pre-sharding for WNA16 act order
if (self.quant_method.__class__.__name__ ==
"CompressedTensorsWNA16MoEMethod"):
moe_quant_params["intermediate_size_full"] = intermediate_size
if quant_config is None:
# Not considering quant for now, temporarily
self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
else:
self.use_nn_moe = False
self.quant_method.create_weights(layer=self, **moe_quant_params)
def _load_per_tensor_weight_scale(self, shard_id: str,
......@@ -372,7 +402,8 @@ class FusedMoE(torch.nn.Module):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
loaded_weight = loaded_weight.narrow(shard_dim, shard_size * tp_rank,
loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim,
shard_size * tp_rank,
shard_size)
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
......@@ -382,7 +413,10 @@ class FusedMoE(torch.nn.Module):
else:
assert shard_id == "w3"
expert_data = expert_data.narrow(shard_dim, shard_size, shard_size)
expert_data.copy_(loaded_weight)
if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_w2(self,
expert_data: torch.Tensor,
......@@ -396,18 +430,24 @@ class FusedMoE(torch.nn.Module):
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if not load_full:
loaded_weight = loaded_weight.narrow(shard_dim,
loaded_weight = loaded_weight.narrow(shard_dim if not self.use_nn_moe else ~shard_dim,
shard_size * tp_rank,
shard_size)
# w2, down_proj: Load into only logical weight of w2.
expert_data.copy_(loaded_weight)
if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def _load_single_value(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, expert_id: int):
param_data = param.data
# Input scales can be loaded directly and should be equal.
param_data[expert_id] = loaded_weight
if not self.use_nn_moe:
param_data[expert_id] = loaded_weight
else:
param_data[expert_id] = loaded_weight.T
def _load_g_idx(self, shard_id: str, expert_data: torch.Tensor,
shard_dim: int, loaded_weight: torch.Tensor, tp_rank: int):
......@@ -419,7 +459,10 @@ class FusedMoE(torch.nn.Module):
tp_rank=tp_rank)
else:
assert shard_id in ("w1", "w3")
expert_data.copy_(loaded_weight)
if not self.use_nn_moe:
expert_data.copy_(loaded_weight)
else:
expert_data.copy_(loaded_weight.T)
def weight_loader(self, param: torch.nn.Parameter,
loaded_weight: torch.Tensor, weight_name: str,
......@@ -450,7 +493,7 @@ class FusedMoE(torch.nn.Module):
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size_per_partition is
is_transposed = getattr(param, "is_transposed", False)
is_transposed = getattr(param, "is_transposed", False) or self.use_nn_moe
shard_dim = SHARD_ID_TO_SHARDED_DIM[shard_id]
if is_transposed:
shard_dim = int(not shard_dim)
......@@ -592,7 +635,8 @@ class FusedMoE(torch.nn.Module):
num_expert_group=self.num_expert_group,
custom_routing_function=self.custom_routing_function,
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias)
e_score_correction_bias=self.e_score_correction_bias,
use_nn_moe=self.use_nn_moe)
if self.reduce_results and self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(
......
......@@ -28,7 +28,8 @@ def get_model_architecture(
architectures = getattr(model_config.hf_config, "architectures", [])
visions = getattr(model_config.hf_config, "visual", []) or getattr(model_config.hf_config, "vision_config", [])
# 'Qwen2VLForConditionalGeneration'
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration', 'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM']
support_nn_architectures = ['LlamaForCausalLM', 'QWenLMHeadModel', 'Qwen2ForCausalLM', 'Qwen2MoeForCausalLM', 'ChatGLMModel', 'ChatGLMForConditionalGeneration',
'BaichuanForCausalLM', 'BloomForCausalLM', 'MedusaModel', 'MixtralForCausalLM', 'MLPSpeculatorPreTrainedModel', 'FalconForCausalLM', 'DeepseekV3ForCausalLM']
if any(arch in architectures for arch in support_nn_architectures):
if os.getenv('LLAMA_NN') != '0':
if (architectures == ['QWenLMHeadModel'] or architectures == ['ChatGLMModel'] ) and visions != []:
......
......@@ -20,6 +20,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Inference-only DeepseekV3 model."""
import os
import re
from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
import torch
......@@ -52,6 +54,7 @@ from .interfaces import SupportsPP
from .utils import (PPMissingLayer, is_pp_missing_parameter,
make_empty_intermediate_tensors_factory, make_layers,
maybe_prefix)
from vllm import _custom_ops as ops
class DeepseekV3MLP(nn.Module):
......@@ -666,6 +669,15 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
self.sampler = get_sampler()
self.make_empty_intermediate_tensors = (
self.model.make_empty_intermediate_tensors)
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_mla = False
if hasattr(vllm_config.model_config, "use_mla"):
self.use_mla = vllm_config.model_config.use_mla
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.model.get_input_embeddings(input_ids)
......@@ -800,4 +812,42 @@ class DeepseekV3ForCausalLM(nn.Module, SupportsPP):
default_weight_loader)
weight_loader(param, loaded_weight)
loaded_params.add(name)
if self.use_llama_nn and self.quant_method is None:
lay_key_words = [
"self_attn.q_a_proj.weight",
"self_attn.kv_a_proj_with_mqa.weight",
"mlp.gate.weight",
"mlp.gate_up_proj.weight",
"mlp.down_proj",
"shared_experts.gate_up_proj",
"shared_experts.down_proj"
]
if not self.use_mla:
lay_key_words.extend([
"self_attn.q_proj.weight",
"self_attn.q_b_proj.weight",
"self_attn.kv_b_proj.weight",
"self_attn.o_proj.weight",
])
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername)
if matches:
_weight = torch.zeros_like(weight.data)
ori_shape =_weight.shape
ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1])
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1],-1)
return loaded_params
......@@ -521,6 +521,13 @@ class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
if "lm_head.weight" in layername:
lay_key_words.append("lm_head.weight")
combined_words = "|".join(lay_key_words)
os.environ['LM_NN'] = '1'
else:
os.environ['LM_NN'] = '0'
matches = re.findall(combined_words, layername)
if matches:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment