Unverified commit 83239ff1, authored by Michael Goin and committed by GitHub
Browse files

Add thread_n=64 support to Marlin MoE (#32360)


Signed-off-by: mgoin <mgoin64@gmail.com>
parent c277fbdf
...@@ -58,7 +58,7 @@ TEMPLATE = ( ...@@ -58,7 +58,7 @@ TEMPLATE = (
"( MARLIN_KERNEL_PARAMS );" "( MARLIN_KERNEL_PARAMS );"
) )
THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128)] THREAD_CONFIGS = [(128, 128, 256), (64, 256, 256), (64, 128, 128), (128, 64, 128)]
THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4] THREAD_M_BLOCKS = [0.5, 1, 2, 3, 4]
......
...@@ -126,14 +126,16 @@ thread_config_t small_batch_thread_configs[] = { ...@@ -126,14 +126,16 @@ thread_config_t small_batch_thread_configs[] = {
// thread_k, thread_n, num_threads // thread_k, thread_n, num_threads
{128, 128, 256}, {128, 128, 256},
{64, 128, 128}}; {64, 128, 128},
{128, 64, 128}};
thread_config_t large_batch_thread_configs[] = { thread_config_t large_batch_thread_configs[] = {
// Ordered by priority // Ordered by priority
// thread_k, thread_n, num_threads // thread_k, thread_n, num_threads
{64, 256, 256}, {64, 256, 256},
{64, 128, 128}}; {64, 128, 128},
{128, 64, 128}};
typedef struct { typedef struct {
int blocks_per_sm; int blocks_per_sm;
......
...@@ -226,6 +226,7 @@ def prepare_fp8_moe_layer_for_marlin( ...@@ -226,6 +226,7 @@ def prepare_fp8_moe_layer_for_marlin(
e = layer.num_experts e = layer.num_experts
k = layer.hidden_size k = layer.hidden_size
n = layer.intermediate_size_per_partition n = layer.intermediate_size_per_partition
w13_n = w13_weight.size(1)
weight_block_size = getattr(layer, "weight_block_size", None) weight_block_size = getattr(layer, "weight_block_size", None)
# WORKSPACE # WORKSPACE
...@@ -240,7 +241,7 @@ def prepare_fp8_moe_layer_for_marlin( ...@@ -240,7 +241,7 @@ def prepare_fp8_moe_layer_for_marlin(
def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor: def repack_weight(name: str, weight: torch.Tensor) -> torch.Tensor:
tensor_list = [] tensor_list = []
if "w13" in name: if "w13" in name:
size_n, size_k = n * 2, k size_n, size_k = w13_n, k
else: else:
size_n, size_k = k, n size_n, size_k = k, n
...@@ -268,7 +269,7 @@ def prepare_fp8_moe_layer_for_marlin( ...@@ -268,7 +269,7 @@ def prepare_fp8_moe_layer_for_marlin(
scales = scales.to(layer.orig_dtype) scales = scales.to(layer.orig_dtype)
tensor_list = [] tensor_list = []
if "w13" in name: if "w13" in name:
size_n, size_k = n * 2, k size_n, size_k = w13_n, k
else: else:
size_n, size_k = k, n size_n, size_k = k, n
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment