Commit 38d80967 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.10.2rc2' into v0.10.2rc2-ori

parents 33650733 880c741b
...@@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize): ...@@ -38,9 +38,7 @@ class MoEPrepareAndFinalizeNoEP(mk.FusedMoEPrepareAndFinalize):
expert_map: Optional[torch.Tensor], expert_map: Optional[torch.Tensor],
apply_router_weight_on_input: bool, apply_router_weight_on_input: bool,
quant_config: FusedMoEQuantConfig, quant_config: FusedMoEQuantConfig,
) -> tuple[torch.Tensor, Optional[torch.Tensor], ) -> mk.PrepareResultType:
Optional[mk.ExpertTokensMetadata], Optional[torch.Tensor],
Optional[torch.Tensor]]:
if apply_router_weight_on_input: if apply_router_weight_on_input:
topk = topk_ids.size(1) topk = topk_ids.size(1)
......
...@@ -420,9 +420,8 @@ def shuffle_weights( ...@@ -420,9 +420,8 @@ def shuffle_weights(
Args: Args:
*tensors: Variable number of torch.Tensor objects. *tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the layout: A pair of integers specifying the block sizes used to divide
block sizes used to divide the tensors during shuffling. the tensors during shuffling. Default is (16, 16).
Default is (16, 16).
Returns: Returns:
A Tuple of shuffled tensors. A Tuple of shuffled tensors.
......
...@@ -10,7 +10,7 @@ like uniform random routing. ...@@ -10,7 +10,7 @@ like uniform random routing.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Optional from typing import Any, Optional
import torch import torch
...@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy): ...@@ -50,7 +50,9 @@ class DistributionBasedRouting(RoutingStrategy):
distributions for testing different routing patterns. distributions for testing different routing patterns.
""" """
def __init__(self, distribution: str = "uniform", **distribution_params): def __init__(self,
distribution: str = "uniform",
**distribution_params: Any):
""" """
Initialize distribution-based routing. Initialize distribution-based routing.
...@@ -244,7 +246,7 @@ class RoutingSimulator: ...@@ -244,7 +246,7 @@ class RoutingSimulator:
cls._routing_strategies[name] = strategy cls._routing_strategies[name] = strategy
@classmethod @classmethod
def get_available_strategies(cls): def get_available_strategies(cls) -> list[str]:
""" """
Get list of available routing strategy names. Get list of available routing strategy names.
......
...@@ -9,11 +9,11 @@ import torch.nn as nn ...@@ -9,11 +9,11 @@ import torch.nn as nn
import vllm.envs as envs import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op
def is_rocm_aiter_rmsnorm_enabled() -> bool: def is_rocm_aiter_rmsnorm_enabled() -> bool:
return current_platform.is_rocm() \ return envs.VLLM_ROCM_USE_AITER_RMSNORM \
and envs.VLLM_ROCM_USE_AITER_RMSNORM \
and envs.VLLM_ROCM_USE_AITER and envs.VLLM_ROCM_USE_AITER
...@@ -43,8 +43,22 @@ def fused_add_rms_norm( ...@@ -43,8 +43,22 @@ def fused_add_rms_norm(
return x, residual return x, residual
def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, def poly_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor,
variance_epsilon: float) -> torch.Tensor: variance_epsilon: float) -> torch.Tensor:
from vllm import _custom_ops as ops
out = torch.empty_like(x)
ops.poly_norm(
out,
x,
weight,
bias,
variance_epsilon,
)
return out
def rocm_aiter_rms_norm_impl(x: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> torch.Tensor:
import aiter as rocm_aiter import aiter as rocm_aiter
if x.dim() > 2: if x.dim() > 2:
x_original_shape = x.shape x_original_shape = x.shape
...@@ -55,7 +69,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor, ...@@ -55,7 +69,7 @@ def rocm_aiter_rms_norm(x: torch.Tensor, weight: torch.Tensor,
return rocm_aiter.rms_norm(x, weight, variance_epsilon) return rocm_aiter.rms_norm(x, weight, variance_epsilon)
def rocm_aiter_fused_add_rms_norm( def rocm_aiter_rmsnorm2d_fwd_with_add_impl(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor, x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]: variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
...@@ -74,14 +88,48 @@ def rocm_aiter_fused_add_rms_norm( ...@@ -74,14 +88,48 @@ def rocm_aiter_fused_add_rms_norm(
return output, residual_out return output, residual_out
def dispatch_cuda_rmsnorm_func(add_residual: bool): def rocm_aiter_rms_norm_fake(x: torch.Tensor, weight: torch.Tensor,
if add_residual: variance_epsilon: float) -> torch.Tensor:
if is_rocm_aiter_rmsnorm_enabled(): return torch.empty_like(x)
return rocm_aiter_fused_add_rms_norm
return fused_add_rms_norm
def rocm_aiter_rmsnorm2d_fwd_with_add_fake(
x: torch.Tensor, residual: torch.Tensor, weight: torch.Tensor,
variance_epsilon: float) -> tuple[torch.Tensor, torch.Tensor]:
return torch.empty_like(x), torch.empty_like(residual)
if is_rocm_aiter_rmsnorm_enabled():
return rocm_aiter_rms_norm if current_platform.is_rocm():
direct_register_custom_op(
op_name="rocm_aiter_rms_norm",
op_func=rocm_aiter_rms_norm_impl,
mutates_args=[],
fake_impl=rocm_aiter_rms_norm_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_rmsnorm2d_fwd_with_add",
op_func=rocm_aiter_rmsnorm2d_fwd_with_add_impl,
mutates_args=[],
fake_impl=rocm_aiter_rmsnorm2d_fwd_with_add_fake,
dispatch_key=current_platform.dispatch_key,
)
def dispatch_rocm_rmsnorm_func(with_fused_add: bool, dtype: torch.dtype):
use_aiter = is_rocm_aiter_rmsnorm_enabled() and dtype in [
torch.float16, torch.bfloat16
]
if use_aiter and with_fused_add:
return torch.ops.vllm.rocm_aiter_rmsnorm2d_fwd_with_add
if use_aiter:
return torch.ops.vllm.rocm_aiter_rms_norm
# fall back to CUDA implementation
if with_fused_add:
return fused_add_rms_norm
return rms_norm return rms_norm
...@@ -114,6 +162,13 @@ class RMSNorm(CustomOp): ...@@ -114,6 +162,13 @@ class RMSNorm(CustomOp):
self.weight = torch.ones(hidden_size) self.weight = torch.ones(hidden_size)
if self.has_weight: if self.has_weight:
self.weight = nn.Parameter(self.weight) self.weight = nn.Parameter(self.weight)
weight_dtype = self.weight.data.dtype
if current_platform.is_rocm():
self.rocm_norm_func = dispatch_rocm_rmsnorm_func(
with_fused_add=False, dtype=weight_dtype)
self.rocm_norm_func_with_add = dispatch_rocm_rmsnorm_func(
with_fused_add=True, dtype=weight_dtype)
def forward_native( def forward_native(
self, self,
...@@ -162,13 +217,27 @@ class RMSNorm(CustomOp): ...@@ -162,13 +217,27 @@ class RMSNorm(CustomOp):
return self.forward_native(x, residual) return self.forward_native(x, residual)
add_residual = residual is not None add_residual = residual is not None
norm_func = dispatch_cuda_rmsnorm_func(add_residual) if add_residual:
return fused_add_rms_norm(x, residual, self.weight.data,
self.variance_epsilon)
else:
return rms_norm(x, self.weight.data, self.variance_epsilon)
def forward_hip(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if self.variance_size_override is not None:
return self.forward_native(x, residual)
add_residual = residual is not None
if add_residual: if add_residual:
return norm_func(x, residual, self.weight.data, return self.rocm_norm_func_with_add(x, residual, self.weight.data,
self.variance_epsilon) self.variance_epsilon)
else: else:
return norm_func(x, self.weight.data, self.variance_epsilon) return self.rocm_norm_func(x, self.weight.data,
self.variance_epsilon)
def forward_xpu( def forward_xpu(
self, self,
...@@ -265,3 +334,48 @@ class GemmaRMSNorm(CustomOp): ...@@ -265,3 +334,48 @@ class GemmaRMSNorm(CustomOp):
self.forward_static) self.forward_static)
self._is_compiled = True self._is_compiled = True
return self.forward_native(x, residual) return self.forward_native(x, residual)
@CustomOp.register("poly_norm")
class PolyNorm(CustomOp):
"""Polynomial normalization.
Computes x -> w_0 * RMSNorm(x^3) + w_1 * RMSNorm(x^2) + w_2 * RMSNorm(x) + b
where w_n is the learned weight and b is the bias.
Refer to https://arxiv.org/html/2411.03884v1
"""
def __init__(
self,
eps: float = 1e-6,
) -> None:
super().__init__()
self.weight = torch.nn.Parameter(torch.ones(3) / 3)
self.bias = torch.nn.Parameter(torch.zeros(1))
self.variance_epsilon = eps
def _norm(self, x):
return x / torch.sqrt(
x.pow(2).mean(-1, keepdim=True) + self.variance_epsilon)
def forward_native(
self,
x: torch.Tensor,
) -> torch.Tensor:
"""PyTorch-native implementation equivalent to forward().
Refer to https://github.com/BryceZhuo/PolyCom?tab=readme-ov-file/README.md
"""
orig_dtype = x.dtype
x_float = x.to(torch.float32)
output = (self.weight[0] * self._norm(x_float**3) +
self.weight[1] * self._norm(x_float**2) +
self.weight[2] * self._norm(x_float) + self.bias)
return output.to(orig_dtype)
def forward_cuda(
self,
x: torch.Tensor,
) -> torch.Tensor:
return poly_norm(x, self.weight, self.bias, self.variance_epsilon)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment