Unverified Commit c5752323 authored by Lu Fang's avatar Lu Fang Committed by GitHub
Browse files

[Model] Support Llama4 in vLLM (#16104)

parent 63375f0c
......@@ -553,6 +553,9 @@ def main(args: argparse.Namespace):
intermediate_size = config.moe_intermediate_size
shard_intermediate_size = 2 * intermediate_size // args.tp_size
else:
if not hasattr(config, "hidden_size"):
# Support for llama4
config = config.text_config
# Default: Mixtral.
E = config.num_local_experts
topk = config.num_experts_per_tok
......
......@@ -840,6 +840,13 @@ See [this page](#generative-models) for more information on how to use generativ
*
* ✅︎
* ✅︎
- * `Llama4ForConditionalGeneration`
* Llama-4-17B-Omni-Instruct
* T + I<sup>+</sup>
* `meta-llama/Llama-4-Scout-17B-16E-Instruct`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8`, `meta-llama/Llama-4-Maverick-17B-128E-Instruct`, etc.
*
*
* ✅︎
- * `LlavaForConditionalGeneration`
* LLaVA-1.5
* T + I<sup>E+</sup>
......
......@@ -582,6 +582,42 @@ def run_mllama(questions: list[str], modality: str) -> ModelRequestData:
)
def run_llama4(questions: list[str], modality: str):
assert modality == "image"
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=4,
tensor_parallel_size=8,
disable_mm_preprocessor_cache=args.disable_mm_preprocessor_cache,
gpu_memory_utilization=0.4,
)
tokenizer = AutoTokenizer.from_pretrained(model_name)
messages = [[{
"role":
"user",
"content": [{
"type": "image"
}, {
"type": "text",
"text": f"{question}"
}]
}] for question in questions]
prompts = tokenizer.apply_chat_template(messages,
add_generation_prompt=True,
tokenize=False)
stop_token_ids = None
return ModelRequestData(
engine_args=engine_args,
prompts=prompts,
stop_token_ids=stop_token_ids,
)
# Molmo
def run_molmo(questions: list[str], modality: str) -> ModelRequestData:
assert modality == "image"
......@@ -907,6 +943,7 @@ model_example_map = {
"minicpmv": run_minicpmv,
"mistral3": run_mistral3,
"mllama": run_mllama,
"llama4": run_llama4,
"molmo": run_molmo,
"NVLM_D": run_nvlm_d,
"paligemma": run_paligemma,
......
......@@ -253,6 +253,43 @@ def load_internvl(question: str, image_urls: list[str]) -> ModelRequestData:
)
def load_llama4(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
engine_args = EngineArgs(
model=model_name,
max_model_len=8192,
max_num_seqs=4,
tensor_parallel_size=8,
limit_mm_per_prompt={"image": len(image_urls)},
)
placeholders = [{"type": "image", "image": url} for url in image_urls]
messages = [{
"role":
"user",
"content": [
*placeholders,
{
"type": "text",
"text": question
},
],
}]
processor = AutoProcessor.from_pretrained(model_name)
prompt = processor.apply_chat_template(messages,
tokenize=False,
add_generation_prompt=True)
return ModelRequestData(
engine_args=engine_args,
prompt=prompt,
image_data=[fetch_image(url) for url in image_urls],
)
def load_mistral3(question: str, image_urls: list[str]) -> ModelRequestData:
model_name = "mistralai/Mistral-Small-3.1-24B-Instruct-2503"
......@@ -567,6 +604,7 @@ model_example_map = {
"h2ovl_chat": load_h2ovl,
"idefics3": load_idefics3,
"internvl_chat": load_internvl,
"llama4": load_llama4,
"mistral3": load_mistral3,
"mllama": load_mllama,
"NVLM_D": load_nvlm_d,
......
......@@ -6,7 +6,7 @@ requests >= 2.26.0
tqdm
blake3
py-cpuinfo
transformers >= 4.50.3
transformers >= 4.51.0
huggingface-hub[hf_xet] >= 0.30.0 # Required for Xet downloads.
tokenizers >= 0.19.1 # Required for Llama 3.
protobuf # Required by LlamaTokenizer.
......
......@@ -30,7 +30,7 @@ mistral_common[opencv] >= 1.5.4 # required for pixtral test
opencv-python-headless >= 4.11.0 # required for video test
datamodel_code_generator # required for minicpm3 test
lm-eval[api]==0.4.8 # required for model evaluation test
transformers==4.50.3
transformers==4.51.0
huggingface-hub[hf_xet]>=0.30.0 # Required for Xet downloads.
# quantization
bitsandbytes>=0.45.3
......
......@@ -645,7 +645,7 @@ tqdm==4.66.6
# transformers
tqdm-multiprocess==0.0.11
# via lm-eval
transformers==4.50.3
transformers==4.51.0
# via
# -r requirements/test.in
# genai-perf
......
......@@ -536,6 +536,22 @@ VLM_TEST_SETTINGS = {
limit_mm_per_prompt={"image": 1},
)],
),
"llama4": VLMTestInfo(
models=["meta-llama/Llama-4-Scout-17B-16E-Instruct"],
prompt_formatter=lambda img_prompt: f"<|begin_of_text|><|header_start|>user<|header_end|>\n\n{img_prompt}<|eot|><|header_start|>assistant<|header_end|>\n\n", # noqa: E501
img_idx_to_prompt=lambda _: "<|image|>",
test_type=(VLMTestType.IMAGE, VLMTestType.MULTI_IMAGE),
distributed_executor_backend="mp",
image_size_factors=[(.25, 0.5, 1.0)],
hf_model_kwargs={"device_map": "auto"},
max_model_len=8192,
max_num_seqs=4,
dtype="bfloat16",
auto_cls=AutoModelForImageTextToText,
tensor_parallel_size=8,
vllm_runner_kwargs={"gpu_memory_utilization": 0.8},
marks=[large_gpu_mark(min_gb=80), multi_gpu_marks(num_gpus=8)],
),
}
# yapf: enable
......
......@@ -280,6 +280,7 @@ def _test_processing_correctness_mistral(
"Skywork/Skywork-R1V-38B",
"fixie-ai/ultravox-v0_5-llama-3_2-1b",
"openai/whisper-large-v3",
"meta-llama/Llama-4-Scout-17B-16E-Instruct",
])
@pytest.mark.parametrize("hit_rate", [0.3, 0.5, 1.0])
@pytest.mark.parametrize("num_batches", [32])
......
# SPDX-License-Identifier: Apache-2.0
"""Tests for Llama4's multimodal preprocessing kwargs."""
import pytest
from vllm.multimodal import MULTIMODAL_REGISTRY
from vllm.transformers_utils.tokenizer import encode_tokens
from ....conftest import _ImageAssets
from ...utils import build_model_context
@pytest.mark.parametrize("model_id",
["meta-llama/Llama-4-Scout-17B-16E-Instruct"])
@pytest.mark.parametrize("mm_processor_kwargs", [{}])
@pytest.mark.parametrize("num_imgs", [1, 5])
@pytest.mark.parametrize("disable_mm_preprocessor_cache", [True, False])
@pytest.mark.parametrize("tokenized_prompt", [True, False])
def test_processor_override(
image_assets: _ImageAssets,
model_id: str,
mm_processor_kwargs: dict,
num_imgs: int,
disable_mm_preprocessor_cache: bool,
tokenized_prompt: bool,
):
"""Ensure llama4 processor works properly."""
ctx = build_model_context(
model_id,
mm_processor_kwargs=mm_processor_kwargs,
limit_mm_per_prompt={"image": num_imgs},
disable_mm_preprocessor_cache=disable_mm_preprocessor_cache,
)
processor = MULTIMODAL_REGISTRY.create_processor(ctx.model_config)
config = processor.info.get_hf_config()
tokenizer = processor.info.get_tokenizer()
hf_processor = processor.info.get_hf_processor()
vocab = tokenizer.get_vocab()
prompt = "<|begin_of_text|><|header_start|>user<|header_end|>" \
+ "<|image|>" * num_imgs \
+ "<|eot|><|header_start|>assistant<|header_end|>"
mm_data = {
"image": [
image_assets[(i % len(image_assets))].pil_image
for i in range(num_imgs)
]
}
if tokenized_prompt:
prompt = encode_tokens(tokenizer, prompt)
processed_inputs = processor.apply(prompt, mm_data, mm_processor_kwargs)
mm_kwargs = processed_inputs["mm_kwargs"]
# place holder replacements
prompt_token_ids = processed_inputs["prompt_token_ids"]
assert prompt_token_ids.count(config.boi_token_index) == num_imgs
assert prompt_token_ids.count(config.eoi_token_index) == num_imgs
assert prompt_token_ids.count(vocab[hf_processor.image_token]) == num_imgs
aspect_ratios = mm_kwargs["aspect_ratios"]
num_x_separators = num_y_separators = 0
for tiles_y, tiles_x in aspect_ratios:
if tiles_x * tiles_y > 1:
num_x_separators += (tiles_x - 1) * tiles_y
num_y_separators += tiles_y
assert prompt_token_ids.count(vocab[hf_processor.tile_token]) \
== num_x_separators
assert prompt_token_ids.count(vocab[hf_processor.tile_global_token]) \
== num_y_separators
# image token offsets
img_locs = processed_inputs["mm_placeholders"].get("image", [])
assert len(img_locs) == num_imgs
assert [img_loc["offset"] for img_loc in img_locs] == \
[i for i, v in enumerate(prompt_token_ids) \
if v == config.boi_token_index]
# patch sizes and masks
assert prompt_token_ids.count(config.image_token_index) \
== sum(img_patch.sum() for img_patch in mm_kwargs["embed_is_patch"])
patch_token_id = vocab[hf_processor.img_patch_token]
num_patches = processed_inputs["prompt_token_ids"].count(patch_token_id)
mm_counts = {"image": num_imgs}
assert num_patches / num_imgs <= \
processor.info.get_mm_max_tokens_per_item(32768, mm_counts)["image"]
num_patches_per_chunk = processor.info.get_patch_per_chunk(
config.vision_config)
assert prompt_token_ids.count(config.image_token_index) \
== mm_kwargs["patches_per_image"].sum() * num_patches_per_chunk
assert mm_kwargs["pixel_values"].shape[0] \
== mm_kwargs["patches_per_image"].sum()
for embed_is_patch, aspect_ratio in zip(mm_kwargs["embed_is_patch"],
mm_kwargs["aspect_ratios"]):
assert embed_is_patch.shape[0] == \
len(tokenizer.encode(
hf_processor._prompt_split_image(
aspect_ratio, num_patches_per_chunk),
add_special_tokens=False))
......@@ -337,6 +337,7 @@ _MULTIMODAL_EXAMPLE_MODELS = {
tokenizer="facebook/bart-base",
trust_remote_code=True), # noqa: E501
"MllamaForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-3.2-11B-Vision-Instruct"), # noqa: E501
"Llama4ForConditionalGeneration": _HfExamplesInfo("meta-llama/Llama-4-Scout-17B-16E-Instruct"), # noqa: E501
"WhisperForConditionalGeneration": _HfExamplesInfo("openai/whisper-large-v3"), # noqa: E501
}
......
......@@ -23,6 +23,11 @@ from .registry import HF_EXAMPLE_MODELS
@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs())
def test_registry_imports(model_arch):
# Llama4ForCausalLM does not have a standalone model
if model_arch == "Llama4ForCausalLM":
return
model_info = HF_EXAMPLE_MODELS.get_hf_info(model_arch)
model_info.check_transformers_version(on_fail="skip")
......@@ -91,9 +96,12 @@ def test_registry_is_pp(model_arch, is_pp, init_cuda):
def test_hf_registry_coverage():
untested_archs = (ModelRegistry.get_supported_archs() -
untested_archs = set(ModelRegistry.get_supported_archs() -
HF_EXAMPLE_MODELS.get_supported_archs())
# Llama4ForCausalLM does not have a standalone model
untested_archs.discard("Llama4ForCausalLM")
assert not untested_archs, (
"Please add the following architectures to "
f"`tests/models/registry.py`: {untested_archs}")
......@@ -354,6 +354,8 @@ class ModelConfig:
self.hf_config = hf_config
self.hf_text_config = get_hf_text_config(self.hf_config)
self.attention_chunk_size = getattr(self.hf_text_config,
"attention_chunk_size", None)
self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision)
......
......@@ -500,7 +500,7 @@ class BaseMultiModalItemTracker(ABC, Generic[_T]):
"internvl_chat", "skywork_chat", "NVLM_D",
"h2ovl_chat"):
return "<image>"
if model_type == "mllama":
if model_type in ("mllama", "llama4"):
return "<|image|>"
if model_type in ("qwen2_vl", "qwen2_5_vl"):
return "<|vision_start|><|image_pad|><|vision_end|>"
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
......@@ -23,6 +23,7 @@ def cutlass_moe_fp8(
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
out_dtype: torch.dtype = torch.half,
apply_router_weight_on_input: bool = False,
) -> torch.Tensor:
"""
This function computes a a8w8-quantized Mixture of Experts (MoE) layer
......@@ -96,8 +97,14 @@ def cutlass_moe_fp8(
n = w2_q.size(1)
topk = topk_ids.size(1)
per_act_token = a1_scale.numel() != 1 if a1_scale is not None else (
a2_scale.numel() != 1 if a2_scale is not None else False)
if apply_router_weight_on_input:
assert topk == 1, \
"apply_router_weight_on_input is only implemented for topk=1"
# TODO: this only works for topK=1, will need to update for topK>1
a = a * topk_weights.to(out_dtype)
a_q, a1_scale = ops.scaled_fp8_quant(
a, a1_scale, use_per_token_if_dynamic=per_act_token)
......@@ -139,6 +146,8 @@ def cutlass_moe_fp8(
ops.cutlass_moe_mm(c2, intemediate_q, w2_q, a2_scale, w2_scale,
expert_offsets[:-1], problem_sizes2, ab_strides2,
ab_strides2, c_strides2)
return (c2[c_map].view(m, topk, k) *
topk_weights.view(m, topk, 1).to(out_dtype)).sum(dim=1)
# Gather tokens
c2 = c2[c_map].view(m, topk, k)
if not apply_router_weight_on_input:
c2 = c2 * topk_weights.view(m, topk, 1).to(out_dtype)
return c2.sum(dim=1)
......@@ -954,6 +954,7 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
......@@ -967,10 +968,10 @@ def inplace_fused_experts(hidden_states: torch.Tensor,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> None:
fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids, True,
activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
activation, apply_router_weight_on_input, use_fp8_w8a8,
use_int8_w8a16, use_int4_w4a16, global_num_experts,
expert_map, w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
def inplace_fused_experts_fake(
......@@ -980,6 +981,7 @@ def inplace_fused_experts_fake(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
......@@ -1010,6 +1012,7 @@ def outplace_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
......@@ -1023,10 +1026,11 @@ def outplace_fused_experts(
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[List[int]] = None) -> torch.Tensor:
return fused_experts_impl(hidden_states, w1, w2, topk_weights, topk_ids,
False, activation, use_fp8_w8a8, use_int8_w8a16,
use_int4_w4a16, global_num_experts, expert_map,
w1_scale, w2_scale, w1_zp, w2_zp, a1_scale,
a2_scale, block_shape)
False, activation, apply_router_weight_on_input,
use_fp8_w8a8, use_int8_w8a16, use_int4_w4a16,
global_num_experts, expert_map, w1_scale,
w2_scale, w1_zp, w2_zp, a1_scale, a2_scale,
block_shape)
def outplace_fused_experts_fake(
......@@ -1084,6 +1088,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
......@@ -1099,6 +1104,7 @@ def fused_experts(hidden_states: torch.Tensor,
allow_deep_gemm: bool = False) -> torch.Tensor:
if (allow_deep_gemm and use_fp8_w8a8
and _valid_deep_gemm(hidden_states, w1, w2, expert_map)):
assert apply_router_weight_on_input is False
return deep_gemm_moe_fp8(
hidden_states=hidden_states,
w1=w1,
......@@ -1122,6 +1128,7 @@ def fused_experts(hidden_states: torch.Tensor,
topk_weights=topk_weights,
topk_ids=topk_ids,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16,
use_int4_w4a16=use_int4_w4a16,
......@@ -1143,6 +1150,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
topk_ids: torch.Tensor,
inplace: bool = False,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
use_fp8_w8a8: bool = False,
use_int8_w8a16: bool = False,
use_int4_w4a16: bool = False,
......@@ -1270,7 +1278,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
False,
apply_router_weight_on_input,
top_k_num,
config,
compute_type=compute_type,
......@@ -1307,7 +1315,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
sorted_token_ids,
expert_ids,
num_tokens_post_padded,
True,
not apply_router_weight_on_input,
1,
config,
compute_type=compute_type,
......
......@@ -65,7 +65,9 @@ class FusedMoEMethodBase(QuantizeMethodBase):
expert_map: Optional[torch.Tensor] = None,
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
raise NotImplementedError
......@@ -156,9 +158,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
return self.forward(x=x,
return self.forward(
x=x,
layer=layer,
router_logits=router_logits,
top_k=top_k,
......@@ -171,7 +175,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function=custom_routing_function,
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias,
activation=activation)
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input)
def forward_cuda(
self,
......@@ -188,6 +193,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
topk_weights, topk_ids = FusedMoE.select_experts(
......@@ -202,13 +208,15 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func=scoring_func,
e_score_correction_bias=e_score_correction_bias)
return fused_experts(hidden_states=x,
return fused_experts(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map)
......@@ -228,9 +236,11 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
activation: str = "silu",
apply_router_weight_on_input: bool = False,
**kwargs,
):
assert activation == "silu", f"{activation} is not supported."
assert apply_router_weight_on_input is False
return layer.ipex_fusion(
x,
use_grouped_topk,
......@@ -259,6 +269,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
......@@ -266,6 +277,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
assert topk_group is None
assert custom_routing_function is None
assert layer is not None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for HPU.")
......@@ -290,12 +302,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert not use_grouped_topk
assert num_expert_group is None
assert topk_group is None
assert custom_routing_function is None
assert apply_router_weight_on_input is False
if scoring_func != "softmax":
raise NotImplementedError(
"Only softmax scoring function is supported for TPU.")
......@@ -401,6 +415,7 @@ class FusedMoE(torch.nn.Module):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
):
super().__init__()
......@@ -486,6 +501,7 @@ class FusedMoE(torch.nn.Module):
self.quant_method = quant_config.get_quant_method(self, prefix)
assert self.quant_method is not None
self.apply_router_weight_on_input = apply_router_weight_on_input
moe_quant_params = {
"num_experts": self.local_num_experts,
"hidden_size": hidden_size,
......@@ -853,6 +869,7 @@ class FusedMoE(torch.nn.Module):
scoring_func=self.scoring_func,
e_score_correction_bias=self.e_score_correction_bias,
activation=self.activation,
apply_router_weight_on_input=self.apply_router_weight_on_input,
)
if self.dp_size > 1:
......
......@@ -92,6 +92,7 @@ class RMSNorm(CustomOp):
eps: float = 1e-6,
var_hidden_size: Optional[int] = None,
has_weight: bool = True,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
......@@ -100,7 +101,9 @@ class RMSNorm(CustomOp):
self.variance_size_override = (None if var_hidden_size == hidden_size
else var_hidden_size)
self.has_weight = has_weight
if dtype is not None:
self.weight = torch.ones(hidden_size, dtype=dtype)
else:
self.weight = torch.ones(hidden_size)
if self.has_weight:
self.weight = nn.Parameter(self.weight)
......
......@@ -469,6 +469,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
custom_routing_function: Optional[Callable] = None,
scoring_func: str = "softmax",
e_score_correction_bias: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
) -> torch.Tensor:
assert activation == "silu", "Only SiLU activation is supported."
......@@ -476,6 +477,10 @@ class AWQMoEMethod(FusedMoEMethodBase):
raise NotImplementedError(
"Expert Parallelism is not supported for "
"fused Marlin MoE method.")
if apply_router_weight_on_input:
raise NotImplementedError(
"Apply router weight on input is not supported for"
"fused Marlin MoE method.")
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment