Unverified commit e808c1df authored by kk, committed by GitHub

Integrate ROCm ater package for ck moe function feasibility (#2854)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Lin, Soga <soga.lin@amd.com>
parent a18ab81d
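The new path is opt-in: throughout the diff it activates only on ROCm (is_hip()) and when the CK_MOE environment variable parses to a non-zero integer, mirroring the existing MOE_PADDING switch. A minimal sketch of that gate, using only names that appear in the diff below:

import os
from sglang.srt.utils import is_hip

# CK_MOE=1 enables the ater / Composable Kernel fused-MoE path on ROCm;
# unset or CK_MOE=0 keeps the default Triton fused_experts path.
use_ck_moe = is_hip() and bool(int(os.getenv("CK_MOE", "0")))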
@@ -16,6 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG TRITON_COMMIT="845d75a"
ARG ATER_REPO="https://github.com/HaiShaw/ater"
ARG CK_COMMITS="fa05ae"
RUN git clone ${SGL_REPO} \
&& cd sglang \
&& if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
@@ -46,6 +50,11 @@ RUN git clone ${TRITON_REPO} \
&& cd python \
&& python3 setup.py install
RUN git clone ${ATER_REPO} \
&& cd ater \
&& git submodule update --init --recursive \
&& PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop
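The ater step clones the repo, pulls its submodules, and pre-builds kernels for gfx942 via setup.py develop. A hedged sketch of a quick post-build check inside the resulting image; the package and attribute names follow the imports used later in this diff:

# Sanity check inside the built container: the develop install above should make
# ater importable and expose the fused-MoE entry point that sglang calls below.
import importlib

ater = importlib.import_module("ater")
fused_moe = importlib.import_module("ater.fused_moe")
assert hasattr(fused_moe, "fused_experts_ck"), "ater built without the CK fused-MoE kernels?"
print("ater OK:", getattr(ater, "__version__", "unknown"))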
# Performance environment variable.
ENV HIP_FORCE_DEV_KERNARG=1
......
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
import os
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple
@@ -18,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.utils import set_weight_attrs
from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs
if torch.cuda.is_available():
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -97,6 +98,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
return
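Both weights follow the same pattern: re-register the parameter with the permuted data, then immediately release the original allocation. An equivalent, illustrative-only helper (the helper name is not part of the diff):

import torch
from sglang.srt.utils import permute_weight

def _permute_moe_weights_for_ck(layer: torch.nn.Module) -> None:
    # Hypothetical helper: one-time conversion of the fused-MoE weights into the
    # tiled layout the CK kernels consume, matching the two blocks above.
    for name in ("w13_weight", "w2_weight"):
        param = getattr(layer, name)
        setattr(layer, name, torch.nn.Parameter(permute_weight(param.data), requires_grad=False))
        torch.cuda.empty_cache()  # free the original-layout tensor right away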
def apply(
self,
layer: torch.nn.Module,
@@ -148,14 +163,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
correction_bias=correction_bias,
)
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
)
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
import ater
from ater.fused_moe import fused_experts_ck
return fused_experts_ck(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
)
else:
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
)
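Both backends consume the same tensors; only the kernel provider differs. A hedged sketch of a standalone call with toy sizes (the shapes, dtypes, and sizes are illustrative assumptions; the ater call simply mirrors the invocation above, and the CK path additionally expects weights already converted by permute_weight, as process_weights_after_loading does):

import os
import torch

M, K, N, E, TOPK = 8, 256, 512, 4, 2  # tokens, hidden, intermediate, experts, top-k (toy sizes)
x = torch.randn(M, K, dtype=torch.float16, device="cuda")
w13 = torch.randn(E, 2 * N, K, dtype=torch.float16, device="cuda")  # fused gate+up projection
w2 = torch.randn(E, K, N, dtype=torch.float16, device="cuda")
topk_weights = torch.rand(M, TOPK, device="cuda")
topk_ids = torch.randint(0, E, (M, TOPK), device="cuda", dtype=torch.int32)

if bool(int(os.getenv("CK_MOE", "0"))):
    from ater.fused_moe import fused_experts_ck
    from sglang.srt.utils import permute_weight
    out = fused_experts_ck(x, permute_weight(w13), permute_weight(w2),
                           topk_weights=topk_weights, topk_ids=topk_ids)
else:
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
    out = fused_experts(x, w13, w2, topk_weights=topk_weights, topk_ids=topk_ids, inplace=True)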
def forward_cpu(self, *args, **kwargs):
raise NotImplementedError("The CPU backend currently does not support MoE.")
......
@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
from sglang.srt.utils import (
get_bool_env_var,
is_hip,
permute_weight,
print_warning_once,
set_weight_attrs,
)
@@ -616,18 +617,30 @@ class Fp8MoEMethod:
layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
# If ROCm, apply weight padding (min. Mem channel contention) only if set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
if is_hip():
if bool(int(os.getenv("CK_MOE", "0"))):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
elif bool(int(os.getenv("MOE_PADDING", "0"))):
# If ROCm, apply weight padding (min. Mem channel contention) only if set
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return
# If checkpoint is fp8, we need to handle that the
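On ROCm only one post-processing step runs, and CK_MOE takes precedence over MOE_PADDING. The padding branch appends zeros along the last (contraction) dimension of each weight to reduce memory-channel contention; a small sketch of the F.pad call it relies on (padding_size here is an illustrative value, the real constant is defined elsewhere in fp8.py):

import torch
import torch.nn.functional as F

padding_size = 256  # illustrative; the actual constant comes from fp8.py
w = torch.zeros(8, 1024, 4096)
padded = F.pad(w, (0, padding_size), "constant", 0)  # pads only the last dimension
print(padded.shape)  # torch.Size([8, 1024, 4352])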
@@ -708,18 +721,30 @@ class Fp8MoEMethod:
max_w13_scales, requires_grad=False
)
# If ROCm, apply weight padding (min. Mem channel contention) only if set
if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
if is_hip():
if bool(int(os.getenv("CK_MOE", "0"))):
layer.w13_weight = torch.nn.Parameter(
permute_weight(layer.w13_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
permute_weight(layer.w2_weight.data),
requires_grad=False,
)
torch.cuda.empty_cache()
elif bool(int(os.getenv("MOE_PADDING", "0"))):
# If ROCm, apply weight padding (min. Mem channel contention) only if set
layer.w13_weight = torch.nn.Parameter(
F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
layer.w2_weight = torch.nn.Parameter(
F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
requires_grad=False,
)
torch.cuda.empty_cache()
return
def apply(
@@ -752,27 +777,55 @@ class Fp8MoEMethod:
correction_bias=correction_bias,
)
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
import ater
from ater.fused_moe import fused_experts_ck
return fused_experts_ck(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
)
else:
# Expert fusion with FP8 quantization
return fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
use_fp8_w8a8=True,
w1_scale=(
layer.w13_weight_scale_inv
if self.block_quant
else layer.w13_weight_scale
),
w2_scale=(
layer.w2_weight_scale_inv
if self.block_quant
else layer.w2_weight_scale
),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
block_shape=self.quant_config.weight_block_size,
)
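The scale tensors handed to either kernel depend on whether the checkpoint is block-quantized, and the selection is identical in both branches; note that the fused_experts_ck call above does not pass block_shape, unlike the Triton path. A compact restatement of the selection logic as an illustrative helper (the helper name is not in the diff; the attribute names are):

def _select_fp8_moe_scales(layer, block_quant: bool):
    # Block-quantized checkpoints carry per-block inverse scales; otherwise a
    # single per-tensor scale is used for each fused weight.
    w1_scale = layer.w13_weight_scale_inv if block_quant else layer.w13_weight_scale
    w2_scale = layer.w2_weight_scale_inv if block_quant else layer.w2_weight_scale
    return w1_scale, w2_scale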
class Fp8KVCacheMethod(BaseKVCacheMethod):
......
@@ -1340,6 +1340,25 @@ def parse_tool_response(text, tools, **kwargs):
return text, call_info_list
def permute_weight(x: torch.Tensor) -> torch.Tensor:
b_ = x.shape[0]
n_ = x.shape[1]
k_ = x.shape[2]
x_ = x
if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
else:
return x_
x_ = x_.permute(0, 1, 3, 4, 2, 5)
x_ = x_.contiguous()
x_ = x_.view(*x.shape)
return x_
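permute_weight keeps the (experts, N, K) shape but retiles the elements: 16-bit weights are viewed as (B, N/16, 16, K/32, 4, 8), 8-bit weights as (B, N/16, 16, K/64, 4, 16), the axes are permuted to put the 4-wide sub-tile ahead of the 16-row block, and the result is flattened back to the original shape. Unsupported dtypes are returned unchanged. A quick sanity check of the transform, assuming the function is importable from sglang.srt.utils as added above:

import torch
from sglang.srt.utils import permute_weight

w = torch.arange(16 * 32, dtype=torch.float16).reshape(1, 16, 32)  # smallest fp16-tileable case
pw = permute_weight(w)
print(pw.shape)            # torch.Size([1, 16, 32]) -- shape is preserved
print(torch.equal(w, pw))  # False -- elements are reordered into the CK tile layout
print(torch.equal(w.flatten().float().sort().values,
                  pw.flatten().float().sort().values))  # True -- same values, new order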
class MultiprocessingSerializer:
@staticmethod
def serialize(obj):
......