Unverified commit 8a798be9, authored by Douglas Lehr, committed by GitHub
Browse files

[ROCm] Enable MXFP4 MoE weight pre-shuffling on gfx950 and update aiter (#34192)


Signed-off-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Doug Lehr <douglehr@amd.com>
Co-authored-by: Gregory Shtrasberg <156009573+gshtras@users.noreply.github.com>
Co-authored-by: tjtanaavllm <tunjian.tan@amd.com>
parent fb455ed5
ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete ARG BASE_IMAGE=rocm/dev-ubuntu-22.04:7.0-complete
ARG TRITON_BRANCH="f332c492" ARG TRITON_BRANCH="57c693b6"
ARG TRITON_REPO="https://github.com/ROCm/triton.git" ARG TRITON_REPO="https://github.com/ROCm/triton.git"
ARG PYTORCH_BRANCH="89075173" ARG PYTORCH_BRANCH="89075173"
ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git" ARG PYTORCH_REPO="https://github.com/ROCm/pytorch.git"
...@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0" ...@@ -9,7 +9,7 @@ ARG PYTORCH_AUDIO_BRANCH="v2.9.0"
ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git" ARG PYTORCH_AUDIO_REPO="https://github.com/pytorch/audio.git"
ARG FA_BRANCH="0e60e394" ARG FA_BRANCH="0e60e394"
ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git" ARG FA_REPO="https://github.com/Dao-AILab/flash-attention.git"
ARG AITER_BRANCH="6af8b687" ARG AITER_BRANCH="v0.1.10.post2"
ARG AITER_REPO="https://github.com/ROCm/aiter.git" ARG AITER_REPO="https://github.com/ROCm/aiter.git"
ARG MORI_BRANCH="2d02c6a9" ARG MORI_BRANCH="2d02c6a9"
ARG MORI_REPO="https://github.com/ROCm/mori.git" ARG MORI_REPO="https://github.com/ROCm/mori.git"
...@@ -239,7 +239,7 @@ RUN pip install pyyaml && cd aiter \ ...@@ -239,7 +239,7 @@ RUN pip install pyyaml && cd aiter \
export HIP_CLANG_PATH=/opt/sccache-wrappers \ export HIP_CLANG_PATH=/opt/sccache-wrappers \
&& sccache --show-stats; \ && sccache --show-stats; \
fi \ fi \
&& PREBUILD_KERNELS=1 GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \ && GPU_ARCHS=${AITER_ROCM_ARCH} python3 setup.py bdist_wheel --dist-dir=dist \
&& if [ "$USE_SCCACHE" = "1" ]; then sccache --show-stats; fi \ && if [ "$USE_SCCACHE" = "1" ]; then sccache --show-stats; fi \
&& ls /app/aiter/dist/*.whl && ls /app/aiter/dist/*.whl
RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install RUN mkdir -p /app/install && cp /app/aiter/dist/*.whl /app/install
......
...@@ -933,7 +933,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod): ...@@ -933,7 +933,15 @@ class QuarkOCP_MX_MoEMethod(QuarkMoEMethod):
layer.w2_weight.view(self.fp4_dtype), layer.w2_weight.view(self.fp4_dtype),
requires_grad=layer.w2_weight.requires_grad, requires_grad=layer.w2_weight.requires_grad,
) )
# Pre-shuffle weight
shuffled_w13, shuffled_w2 = rocm_aiter_ops.shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data
)
layer.w13_weight = torch.nn.Parameter(shuffled_w13, requires_grad=False)
layer.w2_weight = torch.nn.Parameter(shuffled_w2, requires_grad=False)
layer.w13_weight.is_shuffled = True
layer.w2_weight.is_shuffled = True
torch.cuda.empty_cache() torch.cuda.empty_cache()
def get_fused_moe_quant_config( def get_fused_moe_quant_config(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment