Unverified Commit 5f65e2b8 authored by HAI, committed by GitHub

[Performance, Hardware] MoE weights padding to AMD MI300x GPUs (#1836)

parent 4e2af03c
@@ -14,6 +14,7 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 logger = init_logger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 @triton.jit
@@ -263,7 +264,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -464,7 +465,7 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
 ):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -481,7 +482,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
...
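In this first file the change is confined to bookkeeping: padding_size becomes 128 when the MOE_PADDING environment variable is set, and every place that treats the last dimension of an expert weight tensor as the hidden size subtracts it back out. Below is a minimal sketch of that relationship, using plain PyTorch rather than the sglang module itself; the tensor shapes are illustrative assumptions, not values from the commit.

import os

import torch
import torch.nn.functional as F

os.environ["MOE_PADDING"] = "1"  # opt-in switch, same variable the patched module reads
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

# Assumed toy shapes: 4 experts, hidden size 1024, intermediate size 512.
num_experts, intermediate, hidden = 4, 512, 1024
w1 = torch.randn(num_experts, 2 * intermediate, hidden)  # unpadded [E, 2N, K]

# F.pad with (0, padding_size) appends zero columns to the last dimension only,
# mirroring what the second file below does to w13_weight and w2_weight.
w1_padded = F.pad(w1, (0, padding_size), "constant", 0)
assert w1_padded.shape == (num_experts, 2 * intermediate, hidden + padding_size)

# The kernel and the config lookup must therefore be given the logical hidden
# size, which is why the diff above subtracts padding_size from B.shape[2] and
# w2.shape[2].
hidden_states = torch.randn(16, hidden)
assert hidden_states.shape[1] == w1_padded.shape[2] - padding_size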
 # Adapted from
 # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+import os
 from abc import abstractmethod
 from typing import List, Optional, Tuple
 import torch
+import torch.nn.functional as F
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
+from sglang.srt.layers.fused_moe.fused_moe import padding_size
 from sglang.srt.utils import is_hip
 logger = init_logger(__name__)
@@ -506,6 +509,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
         # If checkpoint is fp8, we need to handle that the
@@ -572,6 +588,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                     start += shard_size
             layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return
     def apply(
...
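The padded tail is never read: the kernel is handed the logical sizes (see the subtractions in the first file), so the padding changes only the physical layout of the weights, shifting where each row starts in memory, which is what the in-code comment means by reducing memory-channel contention on MI300x. The feature is gated twice, on is_hip() and on MOE_PADDING being set, and torch.cuda.empty_cache() runs after each re-allocation so the unpadded tensors are released promptly. Here is a small sketch of that layout-only invariant, with assumed shapes and no sglang dependency.

import torch
import torch.nn.functional as F

padding_size = 128
E, K, N = 4, 1024, 512  # assumed expert count, hidden size, intermediate size
w2 = torch.randn(E, K, N)

# Zero-pad the innermost dimension, as the patched weight post-processing above does.
w2_padded = F.pad(w2, (0, padding_size), "constant", 0)

# Numerics are untouched: the kernel only ever reads the first N columns.
assert torch.equal(w2_padded[..., :N], w2)

# What does change is the row stride: each row now starts padding_size elements
# later in memory, which is the knob the commit turns to spread accesses across
# memory channels.
assert w2_padded.stride(1) == w2.stride(1) + padding_size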