"vscode:/vscode.git/clone" did not exist on "a29e62ea3452bc6b1d4f3eeac2dc9a6b30357c4d"
Unverified Commit e1e0fd75 authored by Robert Shaw's avatar Robert Shaw Committed by GitHub
Browse files

[TPU] Avoid Triton Import (#15589)


Signed-off-by: default avatarrshaw@neuralmagic.com <robertgshaw2@gmail.com>
parent df8d3d12
...@@ -16,8 +16,6 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank, ...@@ -16,8 +16,6 @@ from vllm.distributed import (get_dp_group, get_tensor_model_parallel_rank,
from vllm.forward_context import ForwardContext, get_forward_context from vllm.forward_context import ForwardContext, get_forward_context
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase) QuantizationConfig, QuantizeMethodBase)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
...@@ -119,7 +117,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -119,7 +117,9 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight( layer.w2_weight = torch.nn.Parameter(self._maybe_pad_weight(
layer.w2_weight.data), layer.w2_weight.data),
requires_grad=False) requires_grad=False)
# Lazy import to avoid importing triton.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
is_rocm_aiter_moe_enabled, shuffle_weights)
if is_rocm_aiter_moe_enabled(): if is_rocm_aiter_moe_enabled():
# reshaping weights is required for aiter moe kernel. # reshaping weights is required for aiter moe kernel.
shuffled_w13, shuffled_w2 = shuffle_weights( shuffled_w13, shuffled_w2 = shuffle_weights(
......
...@@ -13,9 +13,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size ...@@ -13,9 +13,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase, from vllm.model_executor.layers.fused_moe import (FusedMoE, FusedMoEMethodBase,
FusedMoeWeightScaleSupported) FusedMoeWeightScaleSupported)
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
is_rocm_aiter_moe_enabled, shuffle_weights)
from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase, from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
...@@ -532,6 +529,11 @@ class Fp8MoEMethod(FusedMoEMethodBase): ...@@ -532,6 +529,11 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Lazy import to avoid importing triton too early.
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import (
expand_weights, is_rocm_aiter_block_scaled_moe_enabled,
is_rocm_aiter_moe_enabled, shuffle_weights)
# TODO (rob): refactor block quant into separate class. # TODO (rob): refactor block quant into separate class.
if self.block_quant: if self.block_quant:
assert self.quant_config.activation_scheme == "dynamic" assert self.quant_config.activation_scheme == "dynamic"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment