Commit 6adf9d12 authored by flyingdown's avatar flyingdown
Browse files

Use tuned Triton config for w4a16 MoE

parent 54e03934
...@@ -36,6 +36,7 @@ try: ...@@ -36,6 +36,7 @@ try:
per_token_group_quant_int8, per_token_quant_int8) per_token_group_quant_int8, per_token_quant_int8)
from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json) from lmslim.layers.fused_moe.fuse_moe_int8 import (fused_experts_impl_int8, get_w8a8moe_json)
from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8 from lmslim.layers.fused_moe.fuse_moe_w4a8 import fused_experts_impl_w4a8
from lmslim.layers.fused_moe.fuse_moe_w4a16 import get_moe_triton_config_w4a16
except Exception: except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n") print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
...@@ -1984,7 +1985,15 @@ def fused_experts_impl( ...@@ -1984,7 +1985,15 @@ def fused_experts_impl(
use_nn_moe=use_nn_moe, use_nn_moe=use_nn_moe,
) )
config = get_config_func(M) # config = get_config_func(M)
_, N1, _ = w1.shape
_, N2, K2 = w2.shape
config, _, status = get_moe_triton_config_w4a16(
M, E, N1, N2, K2 * 2, top_k_num, block_shape[1], hidden_states.dtype
)
# debug
# print(f"M:{M}, E:{E}, N1:{N1}, N2:{N2}, K2:{K2}, top_k_num:{top_k_num}, block_shape:{block_shape}, dtype:{hidden_states.dtype}, status:{status}")
assert status, "config not found."
# We can reuse the memory between these because by the time we need # We can reuse the memory between these because by the time we need
# cache3, we're done with cache1 # cache3, we're done with cache1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment