Unverified Commit e808c1df authored by kk, committed by GitHub

Integrate ROCm ater package for CK MoE function feasibility (#2854)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Lin, Soga <soga.lin@amd.com>
parent a18ab81d
@@ -16,6 +16,10 @@ ARG SGL_BRANCH=${SGL_DEFAULT}
 ARG TRITON_REPO="https://github.com/triton-lang/triton.git"
 ARG TRITON_COMMIT="845d75a"
+ARG ATER_REPO="https://github.com/HaiShaw/ater"
+ARG CK_COMMITS="fa05ae"
 RUN git clone ${SGL_REPO} \
     && cd sglang \
     && if [ "${SGL_BRANCH}" = ${SGL_DEFAULT} ]; then \
@@ -46,6 +50,11 @@ RUN git clone ${TRITON_REPO} \
     && cd python \
     && python3 setup.py install
+RUN git clone ${ATER_REPO} \
+    && cd ater \
+    && git submodule update --init --recursive \
+    && PREBUILD_KERNELS=1 GPU_ARCHS=gfx942 python3 setup.py develop
 # Performance environment variable.
 ENV HIP_FORCE_DEV_KERNARG=1
...
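After the image builds, the ater install can be smoke-tested from inside the container. This check is not part of the commit; it assumes only the import paths that the Python changes below rely on.

```python
# Post-build sanity check (not part of this commit): confirm the ater
# package and the CK fused-MoE entry point used below are importable.
import ater
from ater.fused_moe import fused_experts_ck

print("ater module:", ater.__file__)
print("CK MoE entry point available:", callable(fused_experts_ck))
```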
# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
+import os
 from abc import abstractmethod
 from enum import Enum
 from typing import Callable, List, Optional, Tuple
@@ -18,7 +19,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import is_hip, permute_weight, set_weight_attrs

 if torch.cuda.is_available():
     from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
@@ -97,6 +98,20 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_weight", w2_weight)
         set_weight_attrs(w2_weight, extra_weight_attrs)

+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            layer.w13_weight = torch.nn.Parameter(
+                permute_weight(layer.w13_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+            layer.w2_weight = torch.nn.Parameter(
+                permute_weight(layer.w2_weight.data),
+                requires_grad=False,
+            )
+            torch.cuda.empty_cache()
+        return
+
     def apply(
         self,
         layer: torch.nn.Module,
@@ -148,6 +163,18 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )

+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+            )
+        else:
             return fused_experts(
                 hidden_states=x,
                 w1=layer.w13_weight,
...
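Taken together, the two additions in this file pre-permute the unquantized expert weights once after loading and then route apply() through the CK kernel under the same gate. A condensed sketch of that control flow, using only names that appear in this diff (the wrapper function itself is illustrative, not part of the commit):

```python
import os

from sglang.srt.utils import is_hip


def _ck_moe_enabled() -> bool:
    # Same gate the patch uses everywhere: a ROCm build plus the CK_MOE
    # environment variable set to a non-zero integer (e.g. CK_MOE=1).
    return is_hip() and bool(int(os.getenv("CK_MOE", "0")))


def run_experts(layer, x, topk_weights, topk_ids):
    # Illustrative dispatch wrapper. Argument lists are abbreviated to
    # what this diff shows.
    if _ck_moe_enabled():
        # CK path: expects weights already retiled by permute_weight()
        # in process_weights_after_loading().
        from ater.fused_moe import fused_experts_ck

        return fused_experts_ck(
            hidden_states=x,
            w1=layer.w13_weight,
            w2=layer.w2_weight,
            topk_weights=topk_weights,
            topk_ids=topk_ids,
        )
    # Default path: the Triton fused-experts kernel.
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts

    return fused_experts(
        hidden_states=x,
        w1=layer.w13_weight,
        w2=layer.w2_weight,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
    )
```

Note that the CK and weight-permutation branches share one gate, presumably because fused_experts_ck expects the permuted layout: permuting without dispatching to CK (or vice versa) would feed a kernel the wrong memory layout.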
@@ -40,6 +40,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
 from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
+    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
@@ -616,8 +617,20 @@ class Fp8MoEMethod:
             layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
             layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)

-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+            if is_hip():
+                if bool(int(os.getenv("CK_MOE", "0"))):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                         requires_grad=False,
@@ -708,8 +721,20 @@ class Fp8MoEMethod:
             max_w13_scales, requires_grad=False
         )

-            # If ROCm, apply weight padding (min. Mem channel contention) only if set
-            if is_hip() and bool(int(os.getenv("MOE_PADDING", "0"))):
+            if is_hip():
+                if bool(int(os.getenv("CK_MOE", "0"))):
+                    layer.w13_weight = torch.nn.Parameter(
+                        permute_weight(layer.w13_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                    layer.w2_weight = torch.nn.Parameter(
+                        permute_weight(layer.w2_weight.data),
+                        requires_grad=False,
+                    )
+                    torch.cuda.empty_cache()
+                elif bool(int(os.getenv("MOE_PADDING", "0"))):
+                    # If ROCm, apply weight padding (min. Mem channel contention) only if set
                     layer.w13_weight = torch.nn.Parameter(
                         F.pad(layer.w13_weight.data, (0, padding_size), "constant", 0),
                         requires_grad=False,
@@ -752,6 +777,32 @@ class Fp8MoEMethod:
             correction_bias=correction_bias,
         )

+        if is_hip() and bool(int(os.getenv("CK_MOE", "0"))):
+            import ater
+            from ater.fused_moe import fused_experts_ck
+
+            return fused_experts_ck(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                use_fp8_w8a8=True,
+                w1_scale=(
+                    layer.w13_weight_scale_inv
+                    if self.block_quant
+                    else layer.w13_weight_scale
+                ),
+                w2_scale=(
+                    layer.w2_weight_scale_inv
+                    if self.block_quant
+                    else layer.w2_weight_scale
+                ),
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+        else:
             # Expert fusion with FP8 quantization
             return fused_experts(
                 x,
@@ -767,7 +818,9 @@ class Fp8MoEMethod:
                 else layer.w13_weight_scale
             ),
             w2_scale=(
-                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+                layer.w2_weight_scale_inv
+                if self.block_quant
+                else layer.w2_weight_scale
             ),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
...
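The FP8 path mirrors the unquantized one, with two wrinkles: CK_MOE now takes precedence over MOE_PADDING (the two branches are mutually exclusive under the new if/elif), and both the CK and Triton calls select the same scale tensors depending on whether block quantization is active. A small hedged sketch of that shared scale selection (the helper name is invented for illustration; the attribute names come from this diff):

```python
# Illustrative helper, not in the commit: both fused_experts_ck and
# fused_experts receive identical w1/w2 scales, chosen by block_quant.
# Block quantization stores reciprocal scales (*_scale_inv); per-tensor
# quantization stores direct scales (*_scale).
def select_fp8_scales(layer, block_quant: bool):
    w13_scale = (
        layer.w13_weight_scale_inv if block_quant else layer.w13_weight_scale
    )
    w2_scale = (
        layer.w2_weight_scale_inv if block_quant else layer.w2_weight_scale
    )
    return w13_scale, w2_scale
```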
@@ -1340,6 +1340,25 @@ def parse_tool_response(text, tools, **kwargs):
     return text, call_info_list

+def permute_weight(x: torch.Tensor) -> torch.Tensor:
+    b_ = x.shape[0]
+    n_ = x.shape[1]
+    k_ = x.shape[2]
+
+    x_ = x
+    if x.dtype == torch.bfloat16 or x.dtype == torch.float16:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 32), 4, 8)
+    elif x.dtype == torch.float8_e4m3fnuz or x.dtype == torch.int8:
+        x_ = x_.view(int(b_), int(n_ / 16), 16, int(k_ / 64), 4, 16)
+    else:
+        return x_
+
+    x_ = x_.permute(0, 1, 3, 4, 2, 5)
+    x_ = x_.contiguous()
+    x_ = x_.view(*x.shape)
+    return x_
+
 class MultiprocessingSerializer:
     @staticmethod
     def serialize(obj):
...
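For intuition: permute_weight() regroups each expert's (N, K) matrix into tiles, 16x32-element blocks for 16-bit dtypes and 16x64 blocks for fp8/int8, reordered so the CK kernels can read them contiguously. The output shape is unchanged; only the element order differs, and unsupported dtypes pass through untouched. A toy example under the divisibility assumptions the view() calls impose (N % 16 == 0 and K % 32 == 0 for bf16):

```python
import torch

# permute_weight is the function this commit adds to sglang.srt.utils.
from sglang.srt.utils import permute_weight

# Toy expert weights: 2 experts, N=32 rows, K=64 columns, bf16.
w = torch.randn(2, 32, 64, dtype=torch.bfloat16)
w_ck = permute_weight(w)

assert w_ck.shape == w.shape       # shape preserved, only layout retiled
assert not torch.equal(w_ck, w)    # elements reordered (with high probability)

# Unsupported dtypes fall through unchanged (the else branch above).
w_f32 = torch.randn(2, 32, 64, dtype=torch.float32)
assert torch.equal(permute_weight(w_f32), w_f32)
```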