Unverified commit 5f65e2b8, authored by HAI, committed by GitHub

[Performance, Hardware] MoE weight padding for AMD MI300x GPUs (#1836)

parent 4e2af03c
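This commit makes MoE expert-weight padding on ROCm opt-in via the `MOE_PADDING` environment variable: when `MOE_PADDING=1` and the build targets AMD GPUs (`is_hip()`), the last (hidden) dimension of the `w13` and `w2` expert weights is right-padded with 128 zero elements, and the fused-MoE kernels subtract that padding wherever the logical hidden size is needed. A minimal sketch of the shape effect, using illustrative dimensions that are not taken from the diff:

```python
import os

import torch
import torch.nn.functional as F

# Mirrors the diff: pad by 128 elements only when MOE_PADDING=1 is set.
padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0

# Toy expert weight laid out as (num_experts, output_dim, hidden_dim);
# the real tensors are fp8/bf16, but dtype does not matter for the shape math.
num_experts, output_dim, hidden_dim = 8, 4096, 2048
w = torch.randn(num_experts, output_dim, hidden_dim)

# F.pad with (0, padding_size) appends zeros on the right of the last dimension only.
w_padded = F.pad(w, (0, padding_size), "constant", 0)

assert w_padded.shape == (num_experts, output_dim, hidden_dim + padding_size)
# Downstream code recovers the logical hidden size by subtracting the padding.
assert w_padded.shape[2] - padding_size == hidden_dim
```

Per the commit title and the in-code comment, the padding is meant to reduce memory-channel contention on MI300x; the kernels reduce only over the logical `shape[2] - padding_size` columns, so the zero padding never enters the math.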
@@ -14,6 +14,7 @@ from vllm import _custom_ops as ops
 from vllm.logger import init_logger
 logger = init_logger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
 @triton.jit
@@ -263,7 +264,7 @@ def invoke_fused_moe_kernel(
         expert_ids,
         num_tokens_post_padded,
         B.shape[1],
-        B.shape[2],
+        B.shape[2] - padding_size,
         sorted_token_ids.shape[0],
         topk_ids.numel(),
         A.stride(0),
@@ -464,7 +465,7 @@ def fused_experts(
     a2_scale: Optional[torch.Tensor] = None,
 ):
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert hidden_states.shape[1] == w1.shape[2] - padding_size, "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
@@ -481,7 +482,7 @@ def fused_experts(
     get_config_func = functools.partial(
         try_get_optimal_moe_config,
         w1.shape,
-        w2.shape,
+        (w2.shape[0], w2.shape[1], w2.shape[2] - padding_size),
         topk_ids.shape[1],
         "float8" if use_fp8 else None,
         override_config=override_config,
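Because only zeros are appended, every place in the fused-MoE path that derived the hidden size from a weight's last dimension now subtracts `padding_size`: the Triton kernel is launched with `B.shape[2] - padding_size` as the reduction size, the `fused_experts` shape assert compares against `w1.shape[2] - padding_size`, and the tuning-config lookup is given the unpadded `w2` shape. A small sketch of that bookkeeping, again with illustrative shapes:

```python
import torch
import torch.nn.functional as F

padding_size = 128  # what the module-level flag evaluates to under MOE_PADDING=1

hidden = 2048
hidden_states = torch.randn(16, hidden)           # (num_tokens, hidden)
w1 = torch.randn(8, 2 * 1408, hidden)             # (num_experts, 2 * intermediate, hidden)
w1 = F.pad(w1, (0, padding_size), "constant", 0)  # last dim padded to hidden + 128

# The padded tensor no longer matches the activations directly ...
assert hidden_states.shape[1] != w1.shape[2]
# ... so the shape check (and the K passed to the Triton kernel) subtracts the padding.
assert hidden_states.shape[1] == w1.shape[2] - padding_size
```

The remaining hunks below are in the FusedMoE layer module (the file that imports `padding_size` and defines `Fp8MoEMethod`), where the padding itself is applied when weights are processed after loading.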
 # Adapted from
 # https://github.com/vllm-project/vllm/tree/v0.5.4/vllm/model_executor/layers/fused_moe
+import os
 from abc import abstractmethod
 from typing import List, Optional, Tuple
 import torch
+import torch.nn.functional as F
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -18,6 +20,7 @@ from vllm.model_executor.layers.quantization.base_config import (
 from vllm.model_executor.layers.quantization.fp8 import Fp8Config
 from vllm.model_executor.utils import set_weight_attrs
+from sglang.srt.layers.fused_moe.fused_moe import padding_size
 from sglang.srt.utils import is_hip
 logger = init_logger(__name__)
@@ -506,6 +509,19 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 )
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

         # If checkpoint is fp8, we need to handle that the
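In `Fp8MoEMethod.process_weights_after_loading`, the padding is applied once per weight, after quantization, and only when both `is_hip()` and `MOE_PADDING=1` hold: each `Parameter` is rebuilt around `F.pad(weight.data, (0, padding_size), "constant", 0)`, which right-pads only the last dimension, and `torch.cuda.empty_cache()` lets the allocator reclaim the blocks held by the discarded unpadded tensors. A self-contained sketch of that pattern (the stand-in tensor and shapes below are illustrative):

```python
import torch
import torch.nn.functional as F

padding_size = 128  # value of the flag-enabled padding in the diff

# Stand-in for an expert weight that has already been loaded.
device = "cuda" if torch.cuda.is_available() else "cpu"
w2_weight = torch.nn.Parameter(
    torch.randn(8, 2048, 1408, device=device), requires_grad=False
)

# Same pattern as the diff: rebuild the Parameter around a right-padded copy ...
w2_weight = torch.nn.Parameter(
    F.pad(w2_weight.data, (0, padding_size), "constant", 0),
    requires_grad=False,
)
# ... then release the cached memory of the old, unpadded tensor (no-op on CPU).
torch.cuda.empty_cache()

print(w2_weight.shape)  # torch.Size([8, 2048, 1536])
```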
@@ -572,6 +588,18 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                    start += shard_size

            layer.w13_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+            # If ROCm, apply weight padding (min. Mem channel contention) only if set
+            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+                layer.w13_weight = torch.nn.Parameter(
+                    F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
             return

     def apply(
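One operational detail: `padding_size` is evaluated once, at import time of `sglang.srt.layers.fused_moe.fused_moe`, so `MOE_PADDING=1` must be present in the environment before the server process imports sglang; setting it afterwards has no effect. A quick way to check that the flag took effect (no GPU needed), sketched below:

```python
import os

# Must be set before any sglang import pulls in the fused-MoE module,
# because padding_size is computed once at import time.
os.environ["MOE_PADDING"] = "1"

from sglang.srt.layers.fused_moe.fused_moe import padding_size

# 128 once the flag is set; the actual F.pad of the weights additionally
# requires a ROCm build (is_hip()).
print(padding_size)
```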