Commit 09396f62 authored by zhuwenwen's avatar zhuwenwen
Browse files

update layer.py

parent 3d062a1c
......@@ -28,8 +28,8 @@ from vllm.model_executor.layers.fused_moe.config import (
from vllm.model_executor.layers.fused_moe.modular_kernel import (
FusedMoEActivationFormat, FusedMoEModularKernel,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# is_rocm_aiter_moe_enabled)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs
......@@ -228,7 +228,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
self.topk_indices_dtype = None
self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
# self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.rocm_aiter_moe_enabled = False
if self.rocm_aiter_moe_enabled:
from .rocm_aiter_fused_moe import rocm_aiter_fused_experts
self.rocm_aiter_fused_experts = rocm_aiter_fused_experts
......@@ -309,15 +310,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w13_weight.data = self._maybe_pad_weight(layer.w13_weight.data)
layer.w2_weight.data = self._maybe_pad_weight(layer.w2_weight.data)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
shuffle_weights)
# from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
# shuffle_weights)
if self.rocm_aiter_moe_enabled:
shuffled_w13, shuffled_w2 = shuffle_weights(
layer.w13_weight.data, layer.w2_weight.data)
# if self.rocm_aiter_moe_enabled:
# shuffled_w13, shuffled_w2 = shuffle_weights(
# layer.w13_weight.data, layer.w2_weight.data)
layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2
# layer.w13_weight.data = shuffled_w13
# layer.w2_weight.data = shuffled_w2
if current_platform.is_cpu():
if current_platform.get_cpu_architecture() == CpuArchEnum.X86:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment