Unverified Commit e808c1df authored by kk, committed by GitHub

Integrate ROCm ater package for CK MoE support (#2854)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Lin, Soga <soga.lin@amd.com>
parent a18ab81d
@@ -16,6 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
ARG TRITON_COMMIT="845d75a"
ARG ATER_REPO="https://github.com/HaiShaw/ater"
ARG CK_COMMITS="fa05ae"
RUN git clone ${SGL_REPO} \
    && cd sglang \
    && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \

@@ -46,6 +50,11 @@ RUN git clone ${TRITON_REPO} \
    && cd python \
    && python3 setup.py install

RUN git clone ${ATER_REPO} \
    && cd ater \
    && git submodule update --init --recursive \
    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop

# Performance environment variable.
ENV HIP_FORCE_DEV_KERNARG=1
......
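A quick way to confirm the new build step succeeded is to import the package inside the resulting image. This is a minimal smoke-test sketch, assuming it is run in a container built from this Dockerfile; it only checks that ater and the fused_experts_ck entry point used by the MoE layers below are importable.

# Run inside the built ROCm image (assumption: the ater install step above completed).
import ater  # noqa: F401
from ater.fused_moe import fused_experts_ck

print("fused_experts_ck is callable:", callable(fused_experts_ck))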
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py

import os
from abc import abstractmethod
from enum import Enum
from typing import Callable, List, Optional, Tuple

@@ -18,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs

if torch.cuda.is_available():
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

@@ -97,6 +98,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
            layer.w13_weight = torch.nn.Parameter(
                permute_weight(layer.w13_weight.data),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
            layer.w2_weight = torch.nn.Parameter(
                permute_weight(layer.w2_weight.data),
                requires_grad=False,
            )
            torch.cuda.empty_cache()
        return
    def apply(
        self,
        layer: torch.nn.Module,

@@ -148,14 +163,26 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            correction_bias=correction_bias,
        )
        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
            import ater
            from ater.fused_moe import fused_experts_ck

            return fused_experts_ck(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
            )
        else:
            return fused_experts(
                hidden_states=x,
                w1=layer.w13_weight,
                w2=layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
            )
    def forward_cpu(self, *args, **kwargs):
        raise NotImplementedError("The CPU backend currently does not support MoE.")
......
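The CK path is strictly opt-in: both the weight post-processing and the expert dispatch above gate on is_hip() plus the CK_MOE environment variable, and otherwise fall through to the existing Triton fused_experts path unchanged. A minimal usage sketch; the launch command is illustrative and not part of this change.

# Opt into ater's CK fused MoE kernels (ROCm only). With CK_MOE unset or "0",
# behavior is identical to the previous Triton-only code path.
import os

os.environ["CK_MOE"] = "1"
# ...then start the server as usual, e.g.
#   python3 -m sglang.launch_server --model-path <model>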
@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
from sglang.srt.utils import (
    get_bool_env_var,
    is_hip,
    permute_weight,
    print_warning_once,
    set_weight_attrs,
)

@@ -616,18 +617,30 @@ class Fp8MoEMethod:
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            if is_hip():
                if bool(int(os.getenv("CK_MOE", "0"))):
                    layer.w13_weight = torch.nn.Parameter(
                        permute_weight(layer.w13_weight.data),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
                    layer.w2_weight = torch.nn.Parameter(
                        permute_weight(layer.w2_weight.data),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
                elif bool(int(os.getenv("MOE_PADDING", "0"))):
                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
                    layer.w13_weight = torch.nn.Parameter(
                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
                    layer.w2_weight = torch.nn.Parameter(
                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
            return

        # If checkpoint is fp8, we need to handle that the
@@ -708,18 +721,30 @@ class Fp8MoEMethod:
                max_w13_scales, requires_grad=False
            )
            if is_hip():
                if bool(int(os.getenv("CK_MOE", "0"))):
                    layer.w13_weight = torch.nn.Parameter(
                        permute_weight(layer.w13_weight.data),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
                    layer.w2_weight = torch.nn.Parameter(
                        permute_weight(layer.w2_weight.data),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
                elif bool(int(os.getenv("MOE_PADDING", "0"))):
                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
                    layer.w13_weight = torch.nn.Parameter(
                        F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
                    layer.w2_weight = torch.nn.Parameter(
                        F.pad(layer.w2_weight.data, (0, padding_size), "constant", 0),
                        requires_grad=False,
                    )
                    torch.cuda.empty_cache()
            return
    def apply(

@@ -752,27 +777,55 @@ class Fp8MoEMethod:
            correction_bias=correction_bias,
        )
        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
            import ater
            from ater.fused_moe import fused_experts_ck

            return fused_experts_ck(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                use_fp8_w8a8=True,
                w1_scale=(
                    layer.w13_weight_scale_inv
                    if self.block_quant
                    else layer.w13_weight_scale
                ),
                w2_scale=(
                    layer.w2_weight_scale_inv
                    if self.block_quant
                    else layer.w2_weight_scale
                ),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
            )
        else:
            # Expert fusion with FP8 quantization
            return fused_experts(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights=topk_weights,
                topk_ids=topk_ids,
                inplace=True,
                use_fp8_w8a8=True,
                w1_scale=(
                    layer.w13_weight_scale_inv
                    if self.block_quant
                    else layer.w13_weight_scale
                ),
                w2_scale=(
                    layer.w2_weight_scale_inv
                    if self.block_quant
                    else layer.w2_weight_scale
                ),
                a1_scale=layer.w13_input_scale,
                a2_scale=layer.w2_input_scale,
                block_shape=self.quant_config.weight_block_size,
            )
class Fp8KVCacheMethod(BaseKVCacheMethod):
......
@@ -1340,6 +1340,25 @@ def parse_tool_response(text, tools, **kwargs):
    return text, call_info_list
def permute_weight(x: torch.Tensor) -> torch.Tensor:
    b_ = x.shape[0]
    n_ = x.shape[1]
    k_ = x.shape[2]

    x_ = x
    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
    else:
        return x_

    x_ = x_.permute(0, 1, 3, 4, 2, 5)
    x_ = x_.contiguous()
    x_ = x_.view(*x.shape)
    return x_
class MultiprocessingSerializer:
    @staticmethod
    def serialize(obj):
......
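permute_weight only reorders elements: for bf16/fp16 a (num_experts, N, K) tensor is viewed as (E, N/16, 16, K/32, 4, 8), permuted to (E, N/16, K/32, 4, 16, 8), then flattened back, so shape, dtype, and element count are preserved (N must be a multiple of 16, K a multiple of 32 for fp16/bf16 or 64 for fp8/int8). A small shape-preservation check, assuming permute_weight is imported from sglang.srt.utils as added above.

import torch

from sglang.srt.utils import permute_weight

# (num_experts, N, K) with N % 16 == 0 and K % 32 == 0 for fp16/bf16.
w = torch.randn(8, 64, 128, dtype=torch.float16)
w_ck = permute_weight(w)

assert w_ck.shape == w.shape and w_ck.dtype == w.dtype
assert w_ck.is_contiguous()
# Same multiset of values, different in-memory order.
assert torch.equal(
    w_ck.float().flatten().sort().values, w.float().flatten().sort().values
)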