Commit 4eabe123 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/releases/v0.9.0' into v0.9.0-ori

parents 45840cd2 58738772
...@@ -10,6 +10,7 @@ from typing import Literal, Optional, Union ...@@ -10,6 +10,7 @@ from typing import Literal, Optional, Union
from vllm.config import LoRAConfig from vllm.config import LoRAConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.model_loader.tensorizer import TensorizerConfig
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -89,12 +90,31 @@ class PEFTHelper: ...@@ -89,12 +90,31 @@ class PEFTHelper:
return cls(**filtered_dict) return cls(**filtered_dict)
@classmethod @classmethod
def from_local_dir(cls, lora_path: str, def from_local_dir(
max_position_embeddings: Optional[int]) -> "PEFTHelper": cls,
lora_path: str,
max_position_embeddings: Optional[int],
tensorizer_config_dict: Optional[dict] = None) -> "PEFTHelper":
lora_config_path = os.path.join(lora_path, "adapter_config.json") lora_config_path = os.path.join(lora_path, "adapter_config.json")
if tensorizer_config_dict:
tensorizer_config = TensorizerConfig(**tensorizer_config_dict)
tensorizer_args = tensorizer_config._construct_tensorizer_args()
from tensorizer.stream_io import open_stream
lora_config_path = os.path.join(tensorizer_config.lora_dir,
"adapter_config.json")
with open_stream(lora_config_path,
mode="rb",
**tensorizer_args.stream_params) as f:
config = json.load(f)
logger.info("Successfully deserialized LoRA config from %s",
tensorizer_config.lora_dir)
else:
with open(lora_config_path) as f: with open(lora_config_path) as f:
config = json.load(f) config = json.load(f)
config["vllm_max_position_embeddings"] = max_position_embeddings config["vllm_max_position_embeddings"] = max_position_embeddings
return cls.from_dict(config) return cls.from_dict(config)
......
...@@ -31,6 +31,7 @@ class LoRARequest( ...@@ -31,6 +31,7 @@ class LoRARequest(
lora_local_path: Optional[str] = msgspec.field(default=None) lora_local_path: Optional[str] = msgspec.field(default=None)
long_lora_max_len: Optional[int] = None long_lora_max_len: Optional[int] = None
base_model_name: Optional[str] = msgspec.field(default=None) base_model_name: Optional[str] = msgspec.field(default=None)
tensorizer_config_dict: Optional[dict] = None
def __post_init__(self): def __post_init__(self):
if self.lora_local_path: if self.lora_local_path:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import os import os
import re
from typing import Optional, Union from typing import Optional, Union
import huggingface_hub import huggingface_hub
import regex as re
from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError, from huggingface_hub.utils import (EntryNotFoundError, HfHubHTTPError,
HFValidationError, RepositoryNotFoundError) HFValidationError, RepositoryNotFoundError)
from torch import nn from torch import nn
......
...@@ -100,7 +100,8 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -100,7 +100,8 @@ class WorkerLoRAManager(AbstractWorkerManager):
lora_path = get_adapter_absolute_path(lora_request.lora_path) lora_path = get_adapter_absolute_path(lora_request.lora_path)
peft_helper = PEFTHelper.from_local_dir( peft_helper = PEFTHelper.from_local_dir(
lora_path, self.max_position_embeddings) lora_path, self.max_position_embeddings,
lora_request.tensorizer_config_dict)
# Validates the LoRA configuration against requirements before # Validates the LoRA configuration against requirements before
# loading weights, throwing an exception if validation fails. # loading weights, throwing an exception if validation fails.
...@@ -125,6 +126,7 @@ class WorkerLoRAManager(AbstractWorkerManager): ...@@ -125,6 +126,7 @@ class WorkerLoRAManager(AbstractWorkerManager):
self.lora_config.lora_extra_vocab_size, self.lora_config.lora_extra_vocab_size,
embedding_modules=self.embedding_modules, embedding_modules=self.embedding_modules,
embedding_padding_modules=self.embedding_padding_modules, embedding_padding_modules=self.embedding_padding_modules,
tensorizer_config_dict=lora_request.tensorizer_config_dict,
weights_mapper=hf_to_vllm_mapper) weights_mapper=hf_to_vllm_mapper)
except FileNotFoundError as e: except FileNotFoundError as e:
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import json import json
from re import escape as regex_escape
import llguidance import llguidance
from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.guidance_logits_processors import ( from vllm.model_executor.guided_decoding.guidance_logits_processors import (
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import copy
import os import os
from typing import Any from typing import Any
...@@ -34,9 +35,24 @@ class GuidanceLogitsProcessor: ...@@ -34,9 +35,24 @@ class GuidanceLogitsProcessor:
self.grammar = grammar self.grammar = grammar
self.tokenizer = tokenizer self.tokenizer = tokenizer
self.tokenizer_name = tokenizer.name_or_path self.tokenizer_name = tokenizer.name_or_path
self.ll_tokenizer = None
self.ll_matcher = None
self.bitmask = None
self.new_sampling = False self.new_sampling = False
self.initialized = False self.initialized = False
def clone(self) -> "GuidanceLogitsProcessor":
cloned = copy.copy(self)
if self.initialized:
cloned.ll_matcher = llguidance.LLMatcher(
self.ll_tokenizer, # type: ignore[assignment]
self.grammar,
log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
)
self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
return cloned
def _initialize(self): def _initialize(self):
if self.initialized: if self.initialized:
return return
...@@ -56,7 +72,7 @@ class GuidanceLogitsProcessor: ...@@ -56,7 +72,7 @@ class GuidanceLogitsProcessor:
# create reusable bitmask # create reusable bitmask
self.bitmask = llguidance.torch.allocate_token_bitmask( self.bitmask = llguidance.torch.allocate_token_bitmask(
1, self.ll_tokenizer.vocab_size) 1, self.ll_tokenizer.vocab_size) # type: ignore[attr-defined]
self.initialized = True self.initialized = True
...@@ -70,15 +86,17 @@ class GuidanceLogitsProcessor: ...@@ -70,15 +86,17 @@ class GuidanceLogitsProcessor:
self._initialize() self._initialize()
if self.new_sampling and len(input_ids) > 0: if self.new_sampling and len(input_ids) > 0:
self.ll_matcher.consume_token(input_ids[-1]) self.ll_matcher.consume_token( # type: ignore[attr-defined]
err = self.ll_matcher.get_error() input_ids[-1])
err = self.ll_matcher.get_error() # type: ignore[attr-defined]
if err: if err:
logger.warning("Error in LLMatcher: %s", err) logger.warning("Error in LLMatcher: %s", err)
llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask, llguidance.torch.fill_next_token_bitmask(self.ll_matcher, self.bitmask,
0) 0)
llguidance.torch.apply_token_bitmask_inplace( llguidance.torch.apply_token_bitmask_inplace(
scores, self.bitmask.to(scores.device)) scores,
self.bitmask.to(scores.device)) # type: ignore[attr-defined]
self.new_sampling = True self.new_sampling = True
......
...@@ -5,9 +5,9 @@ import concurrent.futures ...@@ -5,9 +5,9 @@ import concurrent.futures
import os import os
from enum import Enum from enum import Enum
from json import dumps as json_dumps from json import dumps as json_dumps
from re import escape as regex_escape
from typing import Optional, Union from typing import Optional, Union
from regex import escape as regex_escape
from transformers import PreTrainedTokenizerBase from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
......
...@@ -56,6 +56,12 @@ class BaseLogitsProcessor: ...@@ -56,6 +56,12 @@ class BaseLogitsProcessor:
self._fsm_state: defaultdict[int, Union[int, self._fsm_state: defaultdict[int, Union[int,
CFGState]] = defaultdict(int) CFGState]] = defaultdict(int)
def clone(self) -> "BaseLogitsProcessor":
cloned = copy.copy(self)
cloned._guide = self._guide.copy()
cloned._fsm_state = copy.deepcopy(self._fsm_state)
return cloned
def __call__(self, input_ids: list[int], def __call__(self, input_ids: list[int],
scores: torch.Tensor) -> torch.Tensor: scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token.""" """Use the FSM to bias the logits before sampling the next token."""
...@@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor): ...@@ -218,6 +224,12 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
reasoner) reasoner)
self._guide = self._guide.copy() self._guide = self._guide.copy()
def clone(self) -> "CFGLogitsProcessor":
cloned = copy.copy(self)
cloned._fsm_state = copy.deepcopy(self._fsm_state)
cloned._guide = self._guide.copy()
return cloned
@lru_cache(maxsize=32) @lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase): def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
import re import regex as re
def has_xgrammar_unsupported_json_features(schema: dict) -> bool: def has_xgrammar_unsupported_json_features(schema: dict) -> bool:
......
...@@ -4,10 +4,10 @@ ...@@ -4,10 +4,10 @@
from __future__ import annotations from __future__ import annotations
import json import json
import re
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any from typing import TYPE_CHECKING, Any
import regex as re
import torch import torch
import vllm.envs import vllm.envs
...@@ -302,6 +302,7 @@ class XGrammarLogitsProcessor: ...@@ -302,6 +302,7 @@ class XGrammarLogitsProcessor:
prefilled: bool = field(default=False) prefilled: bool = field(default=False)
def __post_init__(self): def __post_init__(self):
if self.tokenizer_info is None:
self.tokenizer_info = self.config.tokenizer_info( self.tokenizer_info = self.config.tokenizer_info(
self.config.tokenizer_data) self.config.tokenizer_data)
...@@ -400,7 +401,8 @@ class XGrammarLogitsProcessor: ...@@ -400,7 +401,8 @@ class XGrammarLogitsProcessor:
def clone(self) -> XGrammarLogitsProcessor: def clone(self) -> XGrammarLogitsProcessor:
"""Create a new instance with shared compiled grammar """Create a new instance with shared compiled grammar
but separate state""" but separate state"""
new_processor = XGrammarLogitsProcessor(self.config, self.reasoner) new_processor = XGrammarLogitsProcessor(self.config, self.reasoner,
None, self.tokenizer_info)
# Share the compiled grammar context (immutable after compilation) # Share the compiled grammar context (immutable after compilation)
new_processor.ctx = self.ctx new_processor.ctx = self.ctx
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
""" CUTLASS based Fused MoE kernels.""" """ CUTLASS based Fused MoE kernels."""
import os
from typing import Optional from typing import Optional
import torch import torch
...@@ -271,8 +270,6 @@ def cutlass_moe_fp8( ...@@ -271,8 +270,6 @@ def cutlass_moe_fp8(
FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max() FLOAT4_E2M1_MAX = scalar_types.float4_e2m1f.max()
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
MAX_TOKENS_PER_EXPERT = int(
os.environ.get('VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT', '65536'))
def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
...@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, ...@@ -330,10 +327,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype" assert a.dtype in [torch.half, torch.bfloat16], "Invalid input dtype"
assert (topk_weights.shape[0] == m and topk_ids.shape[0] assert (topk_weights.shape[0] == m and topk_ids.shape[0]
== m), ("topk must be provided for each row of a") == m), ("topk must be provided for each row of a")
assert (m <= MAX_TOKENS_PER_EXPERT), (
f"m must be less than MAX_TOKENS_PER_EXPERT({MAX_TOKENS_PER_EXPERT})"
f" for cutlass_moe_fp4, observed m = {m}. Use"
f" VLLM_MODELOPT_MAX_TOKENS_PER_EXPERT to set this value.")
out_dtype = a.dtype out_dtype = a.dtype
num_topk = topk_ids.shape[1] num_topk = topk_ids.shape[1]
...@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, ...@@ -362,8 +356,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
expert_offsets, expert_offsets,
blockscale_offsets, blockscale_offsets,
num_topk, num_topk,
expert_map=a_map, expert_map=a_map)
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale, c1 = ops.cutlass_fp4_moe_mm(rep_a_fp4, w1_fp4, rep_a_blockscale,
w1_blockscale, w1_alphas, problem_sizes1, w1_blockscale, w1_alphas, problem_sizes1,
...@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor, ...@@ -378,12 +371,7 @@ def cutlass_moe_fp4(a: torch.Tensor, a1_gscale: torch.Tensor,
torch.ops._C.silu_and_mul(intermediate, c1) torch.ops._C.silu_and_mul(intermediate, c1)
int_fp4, int_blockscale = ops.scaled_fp4_experts_quant( int_fp4, int_blockscale = ops.scaled_fp4_experts_quant(
intermediate, intermediate, a2_gscale, expert_offsets, blockscale_offsets, num_topk)
a2_gscale,
expert_offsets,
blockscale_offsets,
num_topk,
MAX_TOKENS_PER_EXPERT=MAX_TOKENS_PER_EXPERT)
c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale, c2 = ops.cutlass_fp4_moe_mm(int_fp4, w2_fp4, int_blockscale, w2_blockscale,
w2_alphas, problem_sizes2, expert_offsets[:-1], w2_alphas, problem_sizes2, expert_offsets[:-1],
......
...@@ -2,13 +2,11 @@ ...@@ -2,13 +2,11 @@
import os import os
import importlib import importlib
import threading
from abc import abstractmethod from abc import abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from typing import Callable, Optional from typing import Callable, Optional
from weakref import WeakValueDictionary
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
...@@ -45,6 +43,7 @@ if current_platform.is_cuda_alike(): ...@@ -45,6 +43,7 @@ if current_platform.is_cuda_alike():
from .pplx_prepare_finalize import PplxPrepareAndFinalize from .pplx_prepare_finalize import PplxPrepareAndFinalize
else: else:
fused_experts = None # type: ignore fused_experts = None # type: ignore
FusedMoEPermuteExpertsUnpermute = None # type: ignore
FusedMoEPrepareAndFinalize = None # type: ignore FusedMoEPrepareAndFinalize = None # type: ignore
if is_rocm_aiter_moe_enabled(): if is_rocm_aiter_moe_enabled():
from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501 from vllm.model_executor.layers.fused_moe.rocm_aiter_fused_moe import ( # noqa: E501
...@@ -52,8 +51,7 @@ if is_rocm_aiter_moe_enabled(): ...@@ -52,8 +51,7 @@ if is_rocm_aiter_moe_enabled():
else: else:
from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk from vllm.model_executor.layers.fused_moe.fused_moe import grouped_topk
if current_platform.is_tpu(): if current_platform.is_tpu():
# the iterative moe implementation is used until the moe_pallas is fixed from .moe_pallas import fused_moe as fused_moe_pallas
from .moe_torch_iterative import fused_moe as fused_moe_pallas
else: else:
fused_moe_pallas = None # type: ignore fused_moe_pallas = None # type: ignore
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -76,7 +74,8 @@ class FusedMoEParallelConfig: ...@@ -76,7 +74,8 @@ class FusedMoEParallelConfig:
@property @property
def use_pplx_kernels(self): def use_pplx_kernels(self):
return self.dp_size > 1 and self.use_ep and has_pplx return self.dp_size > 1 and self.use_ep and \
envs.VLLM_ALL2ALL_BACKEND == "pplx"
@staticmethod @staticmethod
def make(tp_size_: int, dp_size_: int, def make(tp_size_: int, dp_size_: int,
...@@ -199,6 +198,8 @@ class MoEConfig: ...@@ -199,6 +198,8 @@ class MoEConfig:
# TODO: add more quantization params, blocked, per-token, etc. # TODO: add more quantization params, blocked, per-token, etc.
block_size: int = 128 block_size: int = 128
max_num_tokens: int = MOE_DP_CHUNK_SIZE
@property @property
def tp_size(self): def tp_size(self):
return self.moe_parallel_config.tp_size return self.moe_parallel_config.tp_size
...@@ -247,13 +248,59 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -247,13 +248,59 @@ class FusedMoEMethodBase(QuantizeMethodBase):
params_dtype: torch.dtype, **extra_weight_attrs): params_dtype: torch.dtype, **extra_weight_attrs):
raise NotImplementedError raise NotImplementedError
def set_prepare_finalize( def init_prepare_finalize(self, moe: MoEConfig,
self, quant_config: Optional[QuantizationConfig]):
dp_size: int, all2all_manager = get_ep_group().device_communicator.all2all_manager
world_size: int, assert all2all_manager is not None
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool: prepare_finalize = None
return False if moe.use_pplx_kernels:
all_to_all_args = dict(
max_num_tokens=moe.max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=all2all_manager.rank,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else (
(moe.hidden_dim + moe.block_size - 1) // moe.block_size *
torch.float32.itemsize)),
group_name=all2all_manager.cpu_group.group_name,
)
handle = all2all_manager.get_handle(all_to_all_args)
prepare_finalize = PplxPrepareAndFinalize(
handle,
max_num_tokens=moe.max_num_tokens,
world_size=all2all_manager.world_size,
rank=all2all_manager.rank,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
quant_dtype=moe.in_dtype,
)
if prepare_finalize is not None:
experts = self.select_gemm_impl(prepare_finalize)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]
) -> FusedMoEPermuteExpertsUnpermute:
# based on the all2all implementation, select the appropriate
# gemm implementation
raise NotImplementedError(
"Subclass must select appropriate gemm implementation"
" based on the prepare_finalize")
@abstractmethod @abstractmethod
def apply( def apply(
...@@ -277,53 +324,13 @@ class FusedMoEMethodBase(QuantizeMethodBase): ...@@ -277,53 +324,13 @@ class FusedMoEMethodBase(QuantizeMethodBase):
raise NotImplementedError raise NotImplementedError
class AllToAllCache:
def __init__(self):
self._cache: WeakValueDictionary = WeakValueDictionary()
self._lock = threading.RLock() # Reentrant lock for thread safety
def destroy(self):
with self._lock:
# TODO: can we do del self._cache?
for _, a2a in self._cache.items():
a2a.destroy()
def get_or_create(self, **kwargs):
assert has_pplx
import pplx_kernels as pplx
# Create a hashable key from the kwargs
key = tuple(sorted((k, v) for k, v in kwargs.items()))
with self._lock:
instance = self._cache.get(key)
if instance is None:
# TODO (varun): Add support to switch to intranode
# when all communications are within the same
# node.
logger.debug("Create AllToAll %s", kwargs)
instance = pplx.AllToAll.internode(**kwargs)
self._cache[key] = instance
return instance
# Global singleton
_all_to_all_cache = AllToAllCache()
# Factory function as a cleaner interface
def get_all_to_all(**kwargs):
return _all_to_all_cache.get_or_create(**kwargs)
@CustomOp.register("unquantized_fused_moe") @CustomOp.register("unquantized_fused_moe")
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization.""" """MoE method without quantization."""
def __init__(self, moe: MoEConfig): def __init__(self, moe: MoEConfig):
super().__init__() super().__init__()
self.fused_experts = fused_experts self.fused_experts = fused_experts # type: ignore
self.moe = moe self.moe = moe
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
...@@ -333,6 +340,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -333,6 +340,42 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
else: else:
self.rocm_aiter_fused_experts = None # type: ignore self.rocm_aiter_fused_experts = None # type: ignore
def select_gemm_impl(
self, prepare_finalize: Optional[FusedMoEPrepareAndFinalize]):
assert self.fused_experts == fused_experts
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=all2all_manager.world_size,
# dp_size actually means tp_size, bug in pplx kernels
dp_size=all2all_manager.tp_group.world_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
return experts
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, use_nn_moe: bool, params_dtype: torch.dtype, use_nn_moe: bool,
...@@ -392,10 +435,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -392,10 +435,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
shuffle_weights) shuffle_weights)
if self.rocm_aiter_moe_enabled: if self.rocm_aiter_moe_enabled:
# use 2stage ck moe layout shuffled_w13, shuffled_w2 = shuffle_weights(
shuffled_w13, shuffled_w2 = shuffle_weights(layer.w13_weight.data, layer.w13_weight.data, layer.w2_weight.data)
layer.w2_weight.data,
layout=(32, 32))
layer.w13_weight.data = shuffled_w13 layer.w13_weight.data = shuffled_w13
layer.w2_weight.data = shuffled_w2 layer.w2_weight.data = shuffled_w2
...@@ -448,47 +489,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp): ...@@ -448,47 +489,6 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
apply_router_weight_on_input=apply_router_weight_on_input, apply_router_weight_on_input=apply_router_weight_on_input,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe)
def set_prepare_finalize(
self,
dp_size: int,
world_size: int,
prepare_finalize: FusedMoEPrepareAndFinalize,
) -> bool:
assert self.fused_experts == fused_experts
experts: Optional[FusedMoEPermuteExpertsUnpermute] = None
if isinstance(prepare_finalize,
(BatchedPrepareAndFinalize, PplxPrepareAndFinalize)):
logger.debug("BatchedTritonExperts %s", self.moe)
experts = BatchedTritonExperts(
max_num_tokens=MOE_DP_CHUNK_SIZE,
world_size=world_size,
dp_size=dp_size,
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
)
else:
logger.debug("TritonExperts %s", self.moe)
experts = TritonExperts(
use_fp8_w8a8=False,
use_int8_w8a8=False,
use_int8_w8a16=False,
use_int4_w4a16=False,
block_shape=None,
per_channel_quant=False,
)
self.fused_experts = FusedMoEModularKernel(
prepare_finalize,
experts,
)
return True
def forward_cuda( def forward_cuda(
self, self,
layer: torch.nn.Module, layer: torch.nn.Module,
...@@ -702,45 +702,6 @@ def determine_expert_map( ...@@ -702,45 +702,6 @@ def determine_expert_map(
return (local_num_experts, expert_map) return (local_num_experts, expert_map)
def _construct_prepare_finalize(
moe: MoEConfig, quant_config: Optional[QuantizationConfig]
) -> Optional[FusedMoEPrepareAndFinalize]:
max_num_tokens = MOE_DP_CHUNK_SIZE
world_size = moe.ep_size
dp_size = moe.ep_size // moe.dp_size # dp_size actually means TP.
rank = moe.ep_rank
if moe.use_pplx_kernels:
logger.debug("using PplxPrepareAndFinalize")
all_to_all = get_all_to_all(
max_num_tokens=max_num_tokens,
num_experts=moe.num_experts,
experts_per_token=moe.experts_per_token, # topk
rank=rank,
world_size=world_size,
dp_size=dp_size,
hidden_dim=moe.hidden_dim,
hidden_dim_bytes=moe.hidden_dim * moe.in_dtype.itemsize,
# For blocked per token: set to
# ceil_div(hidden_dim, block_size) * sizeof(float32)
# For per-token: set to sizeof(float32)
hidden_dim_scale_bytes=(0 if moe.in_dtype.itemsize != 1 else
((moe.hidden_dim + moe.block_size - 1) //
moe.block_size * torch.float32.itemsize)))
return PplxPrepareAndFinalize(
all_to_all,
max_num_tokens=max_num_tokens,
world_size=world_size,
rank=rank,
dp_size=dp_size,
quant_dtype=moe.in_dtype,
)
return None
class FusedMoE(torch.nn.Module): class FusedMoE(torch.nn.Module):
"""FusedMoE layer for MoE models. """FusedMoE layer for MoE models.
...@@ -854,7 +815,10 @@ class FusedMoE(torch.nn.Module): ...@@ -854,7 +815,10 @@ class FusedMoE(torch.nn.Module):
moe_parallel_config=self.moe_parallel_config, moe_parallel_config=self.moe_parallel_config,
# TODO (bnell): this needs to be fixed for quantized types. # TODO (bnell): this needs to be fixed for quantized types.
in_dtype=params_dtype, in_dtype=params_dtype,
max_num_tokens=MOE_DP_CHUNK_SIZE,
) )
self.moe_config = moe
self.quant_config = quant_config
# Note: get_quant_method will look at the layer's local_num_experts # Note: get_quant_method will look at the layer's local_num_experts
# for heuristic purposes, so it must be initialized first. # for heuristic purposes, so it must be initialized first.
...@@ -862,25 +826,13 @@ class FusedMoE(torch.nn.Module): ...@@ -862,25 +826,13 @@ class FusedMoE(torch.nn.Module):
if quant_config is None: if quant_config is None:
quant_method = UnquantizedFusedMoEMethod(moe) quant_method = UnquantizedFusedMoEMethod(moe)
prepare_finalize = _construct_prepare_finalize(moe, quant_config)
else: else:
quant_method = quant_config.get_quant_method(self, prefix) quant_method = quant_config.get_quant_method(self, prefix)
# No pplx for quantized types yet.
prepare_finalize = None
assert quant_method is not None assert quant_method is not None
assert isinstance(quant_method, FusedMoEMethodBase) assert isinstance(quant_method, FusedMoEMethodBase)
self.quant_method = quant_method self.quant_method = quant_method
if prepare_finalize is not None:
world_size = moe.ep_size
dp_size = int(moe.ep_size // moe.dp_size)
success = self.quant_method.set_prepare_finalize(
dp_size, world_size, prepare_finalize)
if not success:
logger.warning("DP+EP not supported for %s.",
type(self.quant_method))
if quant_config is None: if quant_config is None:
# Not considering quant for now, temporarily # Not considering quant for now, temporarily
# self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1 # self.use_nn_moe = int(os.environ.get('MOE_NN', 1)) == 1
......
...@@ -2,7 +2,23 @@ ...@@ -2,7 +2,23 @@
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from torch_xla.experimental.custom_kernel import _histogram
def _histogram(input: torch.Tensor, min: int, max: int) -> torch.Tensor:
"""
Compute the histogram of a int32 tensor. The bin edges are defined by the
min and max values, with step = 1.
"""
assert input.dtype == torch.int32, "input must be of torch.int32 dtype."
assert min <= max, "min must be less than or equal to max."
def searchsorted(sorted_sequence: torch.Tensor,
values_to_search: torch.Tensor) -> torch.Tensor:
return (sorted_sequence.unsqueeze(1) == values_to_search).sum(dim=1)
bin_edges = torch.linspace(min, max, max - min + 1,
dtype=input.dtype).to(input.device)
return searchsorted(bin_edges, input).to(torch.int32)
def fused_moe( def fused_moe(
...@@ -61,7 +77,7 @@ def fused_moe( ...@@ -61,7 +77,7 @@ def fused_moe(
x = torch.ops.xla.gmm(x, w2, group_sizes) x = torch.ops.xla.gmm(x, w2, group_sizes)
x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size) x = x[topk_argsort_revert_indices].reshape(-1, topk, hidden_size)
x = x * topk_weights.unsqueeze_(dim=-1) x = x * topk_weights.unsqueeze(dim=-1)
x = x.sum(dim=-2) x = x.sum(dim=-2)
x = x.reshape(orig_shape) x = x.reshape(orig_shape)
return x return x
...@@ -182,3 +182,7 @@ def moe_unpermute( ...@@ -182,3 +182,7 @@ def moe_unpermute(
expert_first_token_offset, n_expert, expert_first_token_offset, n_expert,
n_local_expert, topk, hidden_states) n_local_expert, topk, hidden_states)
return hidden_states return hidden_states
def moe_permute_unpermute_supported():
return torch.ops._moe_C.moe_permute_unpermute_supported()
...@@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import ( ...@@ -9,7 +9,6 @@ from vllm.model_executor.layers.fused_moe.utils import (
moe_kernel_quantize_input) moe_kernel_quantize_input)
# Note use: layer.get_all_to_all() to get an AllToAll instance
# The max_num_tokens, world_size and dp_size must be the same # The max_num_tokens, world_size and dp_size must be the same
# as the ones used to create the AllToAll. # as the ones used to create the AllToAll.
class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize): class PplxPrepareAndFinalize(mk.FusedMoEPrepareAndFinalize):
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from enum import IntEnum
from functools import cache from functools import cache
from typing import Optional from typing import Optional
...@@ -9,6 +10,28 @@ from vllm.platforms import current_platform ...@@ -9,6 +10,28 @@ from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
class QuantMethod(IntEnum):
# This allows interfacing with AITER QuantType Enum
# without importing the QuantType from AITER globally.
# Note that these quantization methods are
# supported in AITER package. However,
# not all are used in this module.
NO = 0 # a16w16
PER_TENSOR = 1 # w8a8 (pre_Tensor)
PER_TOKEN = 2 # w8a8/w8a4 (per_Token)
BLOCK_1X128 = 3 # block quantized w8a8 (per_1x128)
BLOCK_128x128 = 4 # block quantized w8a8 (per_128x128)
class ActivationMethod(IntEnum):
# This allows interfacing with AITER ActivationType enum
# without importing the ActivationType enum from AITER globally.
SILU = 0
GELU = 1
@cache @cache
def is_rocm_aiter_moe_enabled() -> bool: def is_rocm_aiter_moe_enabled() -> bool:
return current_platform.is_rocm() \ return current_platform.is_rocm() \
...@@ -29,13 +52,12 @@ def rocm_aiter_asm_moe_tkw1_impl( ...@@ -29,13 +52,12 @@ def rocm_aiter_asm_moe_tkw1_impl(
a16: bool = False, a16: bool = False,
per_tensor_quant_scale: Optional[torch.Tensor] = None, per_tensor_quant_scale: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None,
activation_str: str = "silu") -> torch.Tensor: activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor:
from aiter import ActivationType from aiter import ActivationType
from aiter.fused_moe_bf16_asm import asm_moe_tkw1 from aiter.fused_moe_bf16_asm import asm_moe_tkw1
activation = \ activation = ActivationType(activation_method)
ActivationType.Gelu if activation_str == "gelu" else ActivationType.Silu
return asm_moe_tkw1(hidden_states, return asm_moe_tkw1(hidden_states,
w1, w1,
...@@ -65,163 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake( ...@@ -65,163 +87,7 @@ def rocm_aiter_asm_moe_tkw1_fake(
a16: bool = False, a16: bool = False,
per_tensor_quant_scale: Optional[torch.Tensor] = None, per_tensor_quant_scale: Optional[torch.Tensor] = None,
expert_mask: Optional[torch.Tensor] = None, expert_mask: Optional[torch.Tensor] = None,
activation_str: str = "silu") -> torch.Tensor: activation_method: int = ActivationMethod.SILU.value) -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_fmoe_fp8_blockscale_g1u1_impl(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
hidden_states_dtype: torch.dtype,
expert_mask: torch.Tensor,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
from aiter import fmoe_fp8_blockscale_g1u1
from aiter.fused_moe_bf16_asm import moe_sorting_ck
topk = topk_ids.shape[1]
model_dim = w1.shape[-1]
local_E = E = w1.shape[0]
if expert_mask is not None:
E = expert_mask.numel()
(
sorted_token_ids,
sorted_weight_buf,
sorted_expert_ids,
num_valid_ids,
out_asm,
) = moe_sorting_ck(topk_ids,
topk_weights,
E,
model_dim,
hidden_states_dtype,
expert_mask=expert_mask)
fmoe_fp8_blockscale_g1u1(out_asm, a1, w1, w2, sorted_token_ids,
sorted_weight_buf, sorted_expert_ids,
num_valid_ids, topk,
a1_scale.t().contiguous(),
w1_scale.view(local_E, -1),
w2_scale.view(local_E,
-1), *block_shape, smooth_scale)
return out_asm
def rocm_aiter_fmoe_fp8_blockscale_g1u1_fake(
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
hidden_states_dtype: torch.dtype,
expert_mask: torch.Tensor,
a1: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
w1_scale: torch.Tensor,
w2_scale: torch.Tensor,
a1_scale: torch.Tensor,
block_shape: list[int],
smooth_scale: Optional[torch.Tensor] = None) -> torch.Tensor:
return torch.empty_like(a1, dtype=hidden_states_dtype)
def rocm_aiter_asm_moe_impl(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
activation: str = "silu") -> torch.Tensor:
import aiter.fused_moe_bf16_asm as rocm_aiter_asm_fmoe
from aiter import ActivationType
assert activation in ["silu", "gelu"], "The given activation:" \
f" {activation}" \
" is not supported in" \
" AITER."
if activation == "silu":
aiter_activation = ActivationType.Silu
else:
aiter_activation = ActivationType.Gelu
return rocm_aiter_asm_fmoe.asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
fc1_smooth_scale=fc1_smooth_scale,
fc2_smooth_scale=fc2_smooth_scale,
a16=a16,
activation=aiter_activation)
def rocm_aiter_asm_moe_fake(hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
fc1_smooth_scale: Optional[torch.Tensor] = None,
fc2_smooth_scale: Optional[torch.Tensor] = None,
a16: bool = False,
activation: str = "silu") -> torch.Tensor:
return torch.empty_like(hidden_states)
def rocm_aiter_ck_moe_2stages_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_size: Optional[list[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from aiter.fused_moe_bf16_asm import ck_moe_2stages
return ck_moe_2stages(a1=hidden_states,
w1=w1,
w2=w2,
topk_weight=topk_weights,
topk_ids=topk_ids,
fc1_scale=fc1_scale,
fc2_scale=fc2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_size=block_size,
expert_mask=expert_mask)
def rocm_aiter_ck_moe_2stages_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
fc1_scale: Optional[torch.Tensor] = None,
fc2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_size: Optional[list[int]] = None,
expert_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states) return torch.empty_like(hidden_states)
...@@ -274,6 +140,50 @@ def rocm_aiter_biased_grouped_topk_fake( ...@@ -274,6 +140,50 @@ def rocm_aiter_biased_grouped_topk_fake(
pass pass
def rocm_aiter_fused_moe_impl(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: Optional[torch.Tensor] = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
from aiter import ActivationType, QuantType
from aiter.fused_moe import fused_moe
activation = ActivationType(activation_method)
quant_type = QuantType(quant_method)
return fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, expert_mask,
activation, quant_type, doweight_stage1, w1_scale,
w2_scale, a1_scale, a2_scale)
def rocm_aiter_fused_moe_fake(
hidden_states: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_weight: torch.Tensor,
topk_ids: torch.Tensor,
expert_mask: Optional[torch.Tensor] = None,
activation_method: int = ActivationMethod.SILU.value,
quant_method: int = QuantMethod.NO.value,
doweight_stage1: bool = False,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.empty_like(hidden_states)
if current_platform.is_rocm(): if current_platform.is_rocm():
direct_register_custom_op( direct_register_custom_op(
...@@ -285,26 +195,10 @@ if current_platform.is_rocm(): ...@@ -285,26 +195,10 @@ if current_platform.is_rocm():
) )
direct_register_custom_op( direct_register_custom_op(
op_name="rocm_aiter_fmoe_fp8_blockscale_g1u1", op_name="rocm_aiter_fused_moe",
op_func=rocm_aiter_fmoe_fp8_blockscale_g1u1_impl, op_func=rocm_aiter_fused_moe_impl,
mutates_args=[], mutates_args=[],
fake_impl=rocm_aiter_fmoe_fp8_blockscale_g1u1_fake, fake_impl=rocm_aiter_fused_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_asm_moe",
op_func=rocm_aiter_asm_moe_impl,
mutates_args=[],
fake_impl=rocm_aiter_asm_moe_fake,
dispatch_key=current_platform.dispatch_key,
)
direct_register_custom_op(
op_name="rocm_aiter_ck_moe_2stages",
op_func=rocm_aiter_ck_moe_2stages_impl,
mutates_args=[],
fake_impl=rocm_aiter_ck_moe_2stages_fake,
dispatch_key=current_platform.dispatch_key, dispatch_key=current_platform.dispatch_key,
) )
...@@ -373,32 +267,14 @@ def rocm_aiter_fused_experts( ...@@ -373,32 +267,14 @@ def rocm_aiter_fused_experts(
a2_scale: Optional[torch.Tensor] = None, a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None) -> torch.Tensor: block_shape: Optional[list[int]] = None) -> torch.Tensor:
from vllm.model_executor.layers.quantization.utils.fp8_utils import ( activation_method = (ActivationMethod.SILU
per_token_group_quant_fp8) if activation == "silu" else ActivationMethod.GELU)
# All AITER Fused MoE kernels are expecting the following datatypes # All AITER Fused MoE kernels are expecting the following datatypes
topk_weights = topk_weights.to(torch.float32) topk_weights = topk_weights.to(torch.float32)
topk_ids = topk_ids.to(torch.int32) topk_ids = topk_ids.to(torch.int32)
# w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8:
assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for block scaled moe"
)
assert w1_scale is not None
assert w2_scale is not None
# The default block sizes are 128 in AITER.
block_shape = [128, 128] if block_shape is None else block_shape
a1, a1_scale = per_token_group_quant_fp8(hidden_states, block_shape[1])
return torch.ops.vllm.rocm_aiter_fmoe_fp8_blockscale_g1u1(
topk_ids, topk_weights, hidden_states.dtype, None, a1, w1, w2,
w1_scale, w2_scale, a1_scale, block_shape, None)
# w8a8 per-channel quantization # w8a8 per-channel quantization
elif per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8: if per_channel_quant and apply_router_weight_on_input and use_fp8_w8a8:
# AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input` # AITER tkw1 kernel for FP8 models with `apply_router_weight_on_input`
# This applies topk_weights on the GEMM output of the first FC layer # This applies topk_weights on the GEMM output of the first FC layer
# rather than the second FC. # rather than the second FC.
...@@ -421,42 +297,23 @@ def rocm_aiter_fused_experts( ...@@ -421,42 +297,23 @@ def rocm_aiter_fused_experts(
a16=False, a16=False,
per_tensor_quant_scale=None, per_tensor_quant_scale=None,
expert_mask=None, expert_mask=None,
activation_str=activation) activation_method=activation_method)
# w8a8 per-tensor activation per-tensor weight else:
elif use_fp8_w8a8: quant_method = QuantMethod.NO.value
# w8a8 block-scaled
if block_shape is not None and use_fp8_w8a8:
assert not apply_router_weight_on_input, ( assert not apply_router_weight_on_input, (
"apply_router_weight_on_input is not supported for fp8_w8a8") "apply_router_weight_on_input is\
not supported for block scaled moe")
# - faster static per-tensor-activation static per-tensor-weight assert w1_scale is not None
# fp8 quantization w8a8 assert w2_scale is not None
if a1_scale is not None and a2_scale is not None: quant_method = QuantMethod.BLOCK_128x128.value
return torch.ops.vllm.rocm_aiter_ck_moe_2stages( elif use_fp8_w8a8:
hidden_states=hidden_states, # Currently only per tensor quantization method is enabled.
w1=w1, quant_method = QuantMethod.PER_TENSOR.value
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale)
# - fallback static per-tensor-activation static per-tensor-weight
# fp8 quantization w8a8
# - dynamic per-tensor activation static per-tensor-weight
# fp8 quantization w8a8
return torch.ops.vllm.rocm_aiter_asm_moe(hidden_states=hidden_states,
w1=w1,
w2=w2,
topk_weights=topk_weights,
topk_ids=topk_ids,
fc1_scale=w1_scale,
fc2_scale=w2_scale,
fc1_smooth_scale=None,
fc2_smooth_scale=None,
a16=False,
activation=activation)
if apply_router_weight_on_input: if apply_router_weight_on_input:
assert (topk_weights.dim() == 2 assert (topk_weights.dim() == 2
), "`topk_weights` should be in shape (num_tokens, topk)" ), "`topk_weights` should be in shape (num_tokens, topk)"
...@@ -465,16 +322,19 @@ def rocm_aiter_fused_experts( ...@@ -465,16 +322,19 @@ def rocm_aiter_fused_experts(
topk == 1 topk == 1
), "Only support topk=1 when `apply_router_weight_on_input` is True" ), "Only support topk=1 when `apply_router_weight_on_input` is True"
hidden_states = hidden_states * topk_weights.to(hidden_states.dtype) return torch.ops.vllm.rocm_aiter_fused_moe(
topk_ids = topk_ids.to(torch.int32) hidden_states,
topk_weights = torch.ones_like(topk_weights, dtype=torch.float32) w1,
w2,
return torch.ops.vllm.rocm_aiter_ck_moe_2stages( topk_weights,
hidden_states=hidden_states, topk_ids,
w1=w1, quant_method=quant_method,
w2=w2, activation_method=activation_method,
topk_weights=topk_weights, w1_scale=w1_scale,
topk_ids=topk_ids) w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
doweight_stage1=apply_router_weight_on_input)
def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
...@@ -488,14 +348,21 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor, ...@@ -488,14 +348,21 @@ def rocm_aiter_topk_softmax(topk_weights: torch.Tensor,
return topk_weights, topk_indices return topk_weights, topk_indices
def shuffle_weights(*tensors: torch.Tensor, def shuffle_weights(
layout: tuple[int, int]) -> tuple[torch.Tensor, ...]: *tensors: torch.Tensor, layout: tuple[int, int] = (16, 16)
) -> tuple[torch.Tensor, ...]:
""" """
Applies shuffle_weight function from AITER to each Applies shuffle_weight function from AITER to each
input tensor and returns them. input tensor and returns them.
Rearranges (shuffles) the input tensor/s
into a specified block layout for optimized computation.
Args: Args:
*tensors: Variable number of torch.Tensor objects. *tensors: Variable number of torch.Tensor objects.
layout: A pair of integers specifying the
block sizes used to divide the tensors during shuffling.
Default is (16, 16).
Returns: Returns:
A Tuple of shuffled tensors. A Tuple of shuffled tensors.
...@@ -503,25 +370,3 @@ def shuffle_weights(*tensors: torch.Tensor, ...@@ -503,25 +370,3 @@ def shuffle_weights(*tensors: torch.Tensor,
from aiter.ops.shuffle import shuffle_weight from aiter.ops.shuffle import shuffle_weight
return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors) return tuple(shuffle_weight(tensor, layout=layout) for tensor in tensors)
def expand_weights(*tensors: torch.Tensor,
expansion_dims: list[int]) -> tuple[torch.Tensor, ...]:
"""
Expands the dimensions of input tensors.
Args:
*tensors: A variable number of torch.Tensor objects.
expansion_dims: A list of expansion dimensions
corresponding to each tensor.
Returns:
A Tuple of tensors with expanded dimensions.
"""
assert len(tensors) == len(expansion_dims), \
"Number of tensors must match the number of expansion dimensions."
return tuple(
tensor.unsqueeze(-1).unsqueeze(-1).expand((-1, dim, -1))
for tensor, dim in zip(tensors, expansion_dims))
\ No newline at end of file
...@@ -261,6 +261,7 @@ class ReplicatedLinear(LinearBase): ...@@ -261,6 +261,7 @@ class ReplicatedLinear(LinearBase):
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj) (e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
""" """
def __init__( def __init__(
...@@ -523,6 +524,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -523,6 +524,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj) (e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
""" """
def __init__( def __init__(
...@@ -585,8 +587,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear): ...@@ -585,8 +587,6 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param.shard_id.append(loaded_shard_id) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container) param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight) param.data_container.append(loaded_weight)
if len(param.data_container) == 2:
self.qweight = param.materialize_nested()
return return
param_data = param.data param_data = param.data
...@@ -805,6 +805,7 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -805,6 +805,7 @@ class QKVParallelLinear(ColumnParallelLinear):
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.qkv_proj) (e.g. model.layers.0.qkv_proj)
return_bias: If true, return bias together with outputs in forward pass.
""" """
def __init__( def __init__(
...@@ -979,8 +980,6 @@ class QKVParallelLinear(ColumnParallelLinear): ...@@ -979,8 +980,6 @@ class QKVParallelLinear(ColumnParallelLinear):
param.shard_id.append(loaded_shard_id) param.shard_id.append(loaded_shard_id)
param.shard_id_map[loaded_shard_id] = len(param.data_container) param.shard_id_map[loaded_shard_id] = len(param.data_container)
param.data_container.append(loaded_weight) param.data_container.append(loaded_weight)
if len(param.data_container) == 3:
self.qweight = param.materialize_nested()
return return
param_data = param.data param_data = param.data
...@@ -1155,7 +1154,13 @@ class RowParallelLinear(LinearBase): ...@@ -1155,7 +1154,13 @@ class RowParallelLinear(LinearBase):
bias can be fused with other element-wise operations. bias can be fused with other element-wise operations.
We skip adding bias but instead return it. We skip adding bias but instead return it.
params_dtype: Data type for the parameters. params_dtype: Data type for the parameters.
reduce_results: If true, call all-reduce on output and make Y available
to all GPUs, otherwise, every GPU will have its output
which is Y = X_iA_i
quant_config: Quantization configure. quant_config: Quantization configure.
prefix: The name of the layer in the state dict, including all parents
(e.g. model.layers.0.down_proj)
return_bias: If true, return bias together with outputs in forward pass.
""" """
def __init__( def __init__(
......
...@@ -5,10 +5,9 @@ from dataclasses import dataclass ...@@ -5,10 +5,9 @@ from dataclasses import dataclass
import torch import torch
from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.backends.abstract import AttentionMetadata
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.placeholder_attn import ( from vllm.attention.backends.placeholder_attn import (
PlaceholderAttentionMetadata) PlaceholderAttentionMetadata)
from vllm.attention.backends.xformers import XFormersMetadata from vllm.platforms import current_platform
@dataclass @dataclass
...@@ -23,6 +22,21 @@ class Mamba2Metadata: ...@@ -23,6 +22,21 @@ class Mamba2Metadata:
chunk_offsets: torch.Tensor chunk_offsets: torch.Tensor
def get_platform_metadata_classes() -> tuple[type[AttentionMetadata], ...]:
"""Returns the appropriate metadata classes for the current platform."""
if current_platform.is_rocm():
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata)
return (ROCmFlashAttentionMetadata, PlaceholderAttentionMetadata)
elif current_platform.is_cuda():
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.attention.backends.xformers import XFormersMetadata
return (FlashAttentionMetadata, XFormersMetadata,
PlaceholderAttentionMetadata)
raise ValueError(
f"Unsupported platform for Mamba2: {current_platform.device_type}")
def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor, def _query_start_loc_to_chunk_indices_offsets(query_start_loc: torch.Tensor,
chunk_size: int, chunk_size: int,
total_seqlens: int): total_seqlens: int):
...@@ -78,9 +92,8 @@ def prepare_mamba2_metadata( ...@@ -78,9 +92,8 @@ def prepare_mamba2_metadata(
# Compute seq_idx, chunk_indices and chunk_offsets for prefill only # Compute seq_idx, chunk_indices and chunk_offsets for prefill only
if num_prefills > 0: if num_prefills > 0:
if (isinstance(attn_metadata, attn_metadata_instances = get_platform_metadata_classes()
(FlashAttentionMetadata, XFormersMetadata, if (isinstance(attn_metadata, attn_metadata_instances)
PlaceholderAttentionMetadata))
and attn_metadata.context_lens_tensor is not None): and attn_metadata.context_lens_tensor is not None):
has_initial_states = \ has_initial_states = \
attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,] attn_metadata.context_lens_tensor[:num_prefills] > 0 #[batch,]
......
...@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs ...@@ -34,7 +34,11 @@ from vllm.model_executor.utils import set_weight_attrs
@CustomOp.register("mixer2_gated_rms_norm") @CustomOp.register("mixer2_gated_rms_norm")
class Mixer2RMSNormGated(CustomOp): class Mixer2RMSNormGated(CustomOp):
def __init__(self, full_hidden_size, full_n_groups, eps=1e-6): def __init__(self,
full_hidden_size: int,
full_n_groups: int,
use_rms_norm: bool = True,
eps: float = 1e-6):
super().__init__() super().__init__()
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
self.tp_rank = get_tensor_model_parallel_rank() self.tp_rank = get_tensor_model_parallel_rank()
...@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -44,11 +48,17 @@ class Mixer2RMSNormGated(CustomOp):
self.n_groups = full_hidden_size // self.group_size self.n_groups = full_hidden_size // self.group_size
self.variance_epsilon = eps self.variance_epsilon = eps
self.use_rms_norm = use_rms_norm
if self.use_rms_norm:
# Register norm weight only if we're actually applying RMSNorm
self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size)) self.weight = nn.Parameter(torch.ones(self.per_rank_hidden_size))
set_weight_attrs(self.weight, set_weight_attrs(self.weight,
{"weight_loader": sharded_weight_loader(0)}) {"weight_loader": sharded_weight_loader(0)})
assert self.full_hidden_size % self.tp_size== 0,\ else:
"Tensor parallel world size must divide hidden size." # Avoid checkpoint mismatch by skipping unused parameter
self.register_parameter("weight", None)
assert (self.full_hidden_size % self.tp_size == 0
), "Tensor parallel world size must divide hidden size."
def forward_native( def forward_native(
self, self,
...@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -66,6 +76,8 @@ class Mixer2RMSNormGated(CustomOp):
# the input and then redundantly compute the RMSNorm. # the input and then redundantly compute the RMSNorm.
input_dtype = x.dtype input_dtype = x.dtype
x = x * nn.functional.silu(gate.to(torch.float32)) x = x * nn.functional.silu(gate.to(torch.float32))
if not self.use_rms_norm:
return x.to(input_dtype)
if self.n_groups == 1: if self.n_groups == 1:
if self.tp_size > 1: if self.tp_size > 1:
...@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -74,7 +86,7 @@ class Mixer2RMSNormGated(CustomOp):
global_sums = tensor_model_parallel_all_reduce(local_sums) global_sums = tensor_model_parallel_all_reduce(local_sums)
# Calculate the variance # Calculate the variance
count = self.tp_size * x.shape[-1] count = self.tp_size * x.shape[-1]
variance = (global_sums / count) variance = global_sums / count
else: else:
variance = x.pow(2).mean(-1, keepdim=True) variance = x.pow(2).mean(-1, keepdim=True)
...@@ -105,6 +117,11 @@ class Mixer2RMSNormGated(CustomOp): ...@@ -105,6 +117,11 @@ class Mixer2RMSNormGated(CustomOp):
x: torch.Tensor, x: torch.Tensor,
gate: torch.Tensor, gate: torch.Tensor,
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
input_dtype = x.dtype
if not self.use_rms_norm:
# Keep gate in float32 for numerical stability during silu
return x * nn.functional.silu(gate.to(
torch.float32)).to(input_dtype)
if self.tp_size > 1 or self.n_groups != 1: if self.tp_size > 1 or self.n_groups != 1:
return self.forward_native(x, gate) return self.forward_native(x, gate)
...@@ -182,13 +199,15 @@ def mamba_v2_sharded_weight_loader( ...@@ -182,13 +199,15 @@ def mamba_v2_sharded_weight_loader(
# seem to handle slices well. # seem to handle slices well.
# https://github.com/python/mypy/issues/2410 # https://github.com/python/mypy/issues/2410
param.data[ param.data[
boundary:(boundary + take), # type: ignore[misc] boundary:(boundary + take),
...] = loaded_weight[loaded_start_idx:( # type: ignore[misc] ... # type: ignore[misc]
loaded_start_idx + take)] # type: ignore[misc] ] = loaded_weight[loaded_start_idx:(loaded_start_idx +
take) # type: ignore[misc]
] # type: ignore[misc]
# move indexing boundaries # move indexing boundaries
boundary += shard_size boundary += shard_size
loaded_boundary += (full_dim - extra) loaded_boundary += full_dim - extra
return loader return loader
...@@ -206,7 +225,8 @@ class MambaMixer2(CustomOp): ...@@ -206,7 +225,8 @@ class MambaMixer2(CustomOp):
**selective** state spaces) **selective** state spaces)
""" """
def __init__(self, def __init__(
self,
hidden_size: int, hidden_size: int,
ssm_state_size: int, ssm_state_size: int,
conv_kernel_size: int, conv_kernel_size: int,
...@@ -217,8 +237,10 @@ class MambaMixer2(CustomOp): ...@@ -217,8 +237,10 @@ class MambaMixer2(CustomOp):
num_heads: int = 128, num_heads: int = 128,
head_dim: int = 64, head_dim: int = 64,
rms_norm_eps: float = 1e-5, rms_norm_eps: float = 1e-5,
activation="silu", activation: str = "silu",
quant_config: Optional[QuantizationConfig] = None): use_rms_norm: bool = True,
quant_config: Optional[QuantizationConfig] = None,
):
super().__init__() super().__init__()
# For TP, the sharding plan is as follows: # For TP, the sharding plan is as follows:
...@@ -238,17 +260,16 @@ class MambaMixer2(CustomOp): ...@@ -238,17 +260,16 @@ class MambaMixer2(CustomOp):
self.tp_size = get_tensor_model_parallel_world_size() self.tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank() tp_rank = get_tensor_model_parallel_rank()
assert num_heads % self.tp_size == 0, \ assert (num_heads % self.tp_size == 0
"Tensor parallel world size must divide num heads." ), "Tensor parallel world size must divide num heads."
assert (n_groups % self.tp_size) == 0 or n_groups == 1, \ assert (n_groups % self.tp_size) == 0 or n_groups == 1, (
(
"If tensor parallel world size does not divide num_heads, " "If tensor parallel world size does not divide num_heads, "
"then num_groups must equal 1." "then num_groups must equal 1.")
)
assert self.tp_size == 1 or quant_config is None, \ assert (
"Tensor parallel currently not supported for quantized models." self.tp_size == 1 or quant_config is None
), "Tensor parallel currently not supported for quantized models."
self.ssm_state_size = ssm_state_size self.ssm_state_size = ssm_state_size
self.activation = activation self.activation = activation
...@@ -265,8 +286,7 @@ class MambaMixer2(CustomOp): ...@@ -265,8 +286,7 @@ class MambaMixer2(CustomOp):
self.n_groups = n_groups + extra_groups_for_head_shards( self.n_groups = n_groups + extra_groups_for_head_shards(
n_groups, self.tp_size) n_groups, self.tp_size)
self.conv_dim = (intermediate_size + self.conv_dim = intermediate_size + 2 * self.n_groups * ssm_state_size
2 * self.n_groups * ssm_state_size)
self.conv1d = ColumnParallelLinear( self.conv1d = ColumnParallelLinear(
input_size=conv_kernel_size, input_size=conv_kernel_size,
output_size=self.conv_dim, output_size=self.conv_dim,
...@@ -279,11 +299,12 @@ class MambaMixer2(CustomOp): ...@@ -279,11 +299,12 @@ class MambaMixer2(CustomOp):
# doesn't allow to override it # doesn't allow to override it
self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1) self.conv1d.weight.data = self.conv1d.weight.data.unsqueeze(1)
self.in_proj = ColumnParallelLinear(input_size=hidden_size, self.in_proj = ColumnParallelLinear(
output_size=intermediate_size + input_size=hidden_size,
self.conv_dim + self.num_heads, output_size=intermediate_size + self.conv_dim + self.num_heads,
bias=use_bias, bias=use_bias,
quant_config=quant_config) quant_config=quant_config,
)
# - because in_proj is a concatenation of 3 weights, we # - because in_proj is a concatenation of 3 weights, we
# need to interleave them before sharding # need to interleave them before sharding
...@@ -305,7 +326,8 @@ class MambaMixer2(CustomOp): ...@@ -305,7 +326,8 @@ class MambaMixer2(CustomOp):
# - ditto for the otther two weights below # - ditto for the otther two weights below
delattr(self.conv1d.bias, "weight_loader") delattr(self.conv1d.bias, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.bias, { self.conv1d.bias,
{
"weight_loader": "weight_loader":
mamba_v2_sharded_weight_loader( mamba_v2_sharded_weight_loader(
[ [
...@@ -316,18 +338,25 @@ class MambaMixer2(CustomOp): ...@@ -316,18 +338,25 @@ class MambaMixer2(CustomOp):
self.tp_size, self.tp_size,
tp_rank, tp_rank,
) )
}) },
)
delattr(self.conv1d.weight, "weight_loader") delattr(self.conv1d.weight, "weight_loader")
set_weight_attrs( set_weight_attrs(
self.conv1d.weight, { self.conv1d.weight,
{
"weight_loader": "weight_loader":
mamba_v2_sharded_weight_loader([ mamba_v2_sharded_weight_loader(
[
intermediate_settings, intermediate_settings,
group_shard_settings, group_shard_settings,
group_shard_settings, group_shard_settings,
], self.tp_size, tp_rank) ],
}) self.tp_size,
tp_rank,
)
},
)
if quant_config is None: if quant_config is None:
# - quant layers do not have a weight loader # - quant layers do not have a weight loader
...@@ -345,8 +374,10 @@ class MambaMixer2(CustomOp): ...@@ -345,8 +374,10 @@ class MambaMixer2(CustomOp):
head_setings, # for dt head_setings, # for dt
], ],
self.tp_size, self.tp_size,
tp_rank) tp_rank,
}) )
},
)
# - these are TPed by heads to reduce the size of the # - these are TPed by heads to reduce the size of the
# temporal shape # temporal shape
...@@ -357,6 +388,7 @@ class MambaMixer2(CustomOp): ...@@ -357,6 +388,7 @@ class MambaMixer2(CustomOp):
)) ))
self.D = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.D = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size)) self.dt_bias = nn.Parameter(torch.ones(num_heads // self.tp_size))
self.use_rms_norm = use_rms_norm
set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)}) set_weight_attrs(self.D, {"weight_loader": sharded_weight_loader(0)})
a_weight_loader = composed_weight_loader( a_weight_loader = composed_weight_loader(
...@@ -365,18 +397,25 @@ class MambaMixer2(CustomOp): ...@@ -365,18 +397,25 @@ class MambaMixer2(CustomOp):
set_weight_attrs(self.dt_bias, set_weight_attrs(self.dt_bias,
{"weight_loader": sharded_weight_loader(0)}) {"weight_loader": sharded_weight_loader(0)})
self.out_proj = RowParallelLinear(intermediate_size, self.out_proj = RowParallelLinear(
intermediate_size,
hidden_size, hidden_size,
bias=use_bias, bias=use_bias,
input_is_parallel=True, input_is_parallel=True,
quant_config=quant_config) quant_config=quant_config,
)
self.norm = Mixer2RMSNormGated(intermediate_size, self.norm = Mixer2RMSNormGated(intermediate_size,
n_groups, n_groups,
self.use_rms_norm,
eps=rms_norm_eps) eps=rms_norm_eps)
def forward_native(self, hidden_states: torch.Tensor, def forward_native(
conv_state: torch.Tensor, ssm_state: torch.Tensor): self,
hidden_states: torch.Tensor,
conv_state: torch.Tensor,
ssm_state: torch.Tensor,
):
pass pass
def forward_cuda( def forward_cuda(
...@@ -384,6 +423,7 @@ class MambaMixer2(CustomOp): ...@@ -384,6 +423,7 @@ class MambaMixer2(CustomOp):
hidden_states: torch.Tensor, hidden_states: torch.Tensor,
mamba_cache_params: MambaCacheParams, mamba_cache_params: MambaCacheParams,
mamba2_metadata: Mamba2Metadata, mamba2_metadata: Mamba2Metadata,
mup_vector: Optional[torch.Tensor] = None,
): ):
# mamba2_metadata contains metadata necessary for the mamba2 triton # mamba2_metadata contains metadata necessary for the mamba2 triton
# kernels to operate in continuous batching and in chunked prefill # kernels to operate in continuous batching and in chunked prefill
...@@ -401,6 +441,10 @@ class MambaMixer2(CustomOp): ...@@ -401,6 +441,10 @@ class MambaMixer2(CustomOp):
# 1. Gated MLP's linear projection # 1. Gated MLP's linear projection
projected_states, _ = self.in_proj(hidden_states) projected_states, _ = self.in_proj(hidden_states)
if mup_vector is not None:
projected_states = projected_states * mup_vector
gate, hidden_states_B_C, dt = torch.split( gate, hidden_states_B_C, dt = torch.split(
projected_states, projected_states,
[ [
...@@ -561,6 +605,9 @@ class MambaMixer2(CustomOp): ...@@ -561,6 +605,9 @@ class MambaMixer2(CustomOp):
hidden_states = torch.vstack(ssd_output_list) hidden_states = torch.vstack(ssd_output_list)
# 4. gated MLP # 4. gated MLP
# GatedRMSNorm internally applying SiLU to the gate
# SiLU is applied internally before normalization, unlike standard
# norm usage
hidden_states = self.norm(hidden_states, gate) hidden_states = self.norm(hidden_states, gate)
# 5. Final linear projection # 5. Final linear projection
......
...@@ -14,7 +14,7 @@ QuantizationMethods = Literal[ ...@@ -14,7 +14,7 @@ QuantizationMethods = Literal[
"ptpc_fp8", "ptpc_fp8",
"fbgemm_fp8", "fbgemm_fp8",
"modelopt", "modelopt",
"nvfp4", "modelopt_fp4",
"marlin", "marlin",
"bitblas", "bitblas",
"gguf", "gguf",
...@@ -120,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]: ...@@ -120,7 +120,7 @@ def get_quantization_config(quantization: str) -> type[QuantizationConfig]:
"fp8": Fp8Config, "fp8": Fp8Config,
"fbgemm_fp8": FBGEMMFp8Config, "fbgemm_fp8": FBGEMMFp8Config,
"modelopt": ModelOptFp8Config, "modelopt": ModelOptFp8Config,
"nvfp4": ModelOptNvFp4Config, "modelopt_fp4": ModelOptNvFp4Config,
"marlin": MarlinConfig, "marlin": MarlinConfig,
"bitblas": BitBLASConfig, "bitblas": BitBLASConfig,
"gguf": GGUFConfig, "gguf": GGUFConfig,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment