"vscode:/vscode.git/clone" did not exist on "8f55962a7ffcf310ecb462f3a547e593e4fd77bd"
Commit 8db76782 authored by 王敏's avatar 王敏
Browse files

[fix]1.修复medusa、scorer等并行解码单测;2.修复moe...

[fix]1.修复medusa、scorer等并行解码单测;2.修复moe kernel单测问题,优化代码;3.修复rejection_sampler中test_compare_nonflashinfer_backend单测问题
parent a9b29641
...@@ -13,6 +13,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock ...@@ -13,6 +13,7 @@ from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock
import vllm.model_executor.layers.fused_moe # noqa import vllm.model_executor.layers.fused_moe # noqa
from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe, from tests.kernels.utils import (opcheck, stack_and_dev, torch_moe,
torch_moe_single) torch_moe_single)
from vllm.config import VllmConfig, set_current_vllm_config
from vllm.model_executor.layers.fused_moe import fused_moe from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk from vllm.model_executor.layers.fused_moe.fused_moe import fused_topk
from vllm.model_executor.layers.fused_moe.moe_torch_iterative import ( from vllm.model_executor.layers.fused_moe.moe_torch_iterative import (
...@@ -29,6 +30,9 @@ NUM_EXPERTS = [8, 64] ...@@ -29,6 +30,9 @@ NUM_EXPERTS = [8, 64]
EP_SIZE = [1, 4] EP_SIZE = [1, 4]
TOP_KS = [2, 6] TOP_KS = [2, 6]
vllm_config = VllmConfig()
vllm_config.scheduler_config.max_num_seqs = 128
vllm_config.scheduler_config.max_model_len = 8192
@pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128]) @pytest.mark.parametrize("m", [1, 33, 64, 222, 1024 * 128])
@pytest.mark.parametrize("n", [128, 1024, 2048]) @pytest.mark.parametrize("n", [128, 1024, 2048])
...@@ -67,6 +71,7 @@ def test_fused_moe( ...@@ -67,6 +71,7 @@ def test_fused_moe(
else: else:
e_map = None e_map = None
with set_current_vllm_config(vllm_config):
torch_output = torch_moe(a, w1, w2, score, topk, e_map) torch_output = torch_moe(a, w1, w2, score, topk, e_map)
iterative_output = iterative_moe(a, iterative_output = iterative_moe(a,
w1, w1,
...@@ -92,6 +97,7 @@ def test_fused_moe( ...@@ -92,6 +97,7 @@ def test_fused_moe(
global_num_experts=e, global_num_experts=e,
expert_map=e_map, expert_map=e_map,
renormalize=False) renormalize=False)
torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0) torch.testing.assert_close(triton_output, torch_output, atol=2e-2, rtol=0)
torch.testing.assert_close(iterative_output, torch.testing.assert_close(iterative_output,
torch_output, torch_output,
......
...@@ -276,6 +276,8 @@ def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int, ...@@ -276,6 +276,8 @@ def test_mixed_seeded_batch(k: int, vocab_size: int, batch_size: int,
assert torch.equal(batch_result[i], results[i].squeeze(0)) assert torch.equal(batch_result[i], results[i].squeeze(0))
@pytest.mark.skipif(current_platform.is_rocm(),
reason="Rocm platform does not support flashinfer.")
@pytest.mark.parametrize("k", [1, 3, 6]) @pytest.mark.parametrize("k", [1, 3, 6])
@pytest.mark.parametrize("vocab_size", [30_000, 50_000]) @pytest.mark.parametrize("vocab_size", [30_000, 50_000])
@pytest.mark.parametrize("batch_size", [1, 8, 32, 128]) @pytest.mark.parametrize("batch_size", [1, 8, 32, 128])
......
...@@ -11,6 +11,7 @@ from ..utils import maybe_enable_chunked_prefill ...@@ -11,6 +11,7 @@ from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
...@@ -44,7 +45,7 @@ from ...utils import models_path_prefix ...@@ -44,7 +45,7 @@ from ...utils import models_path_prefix
]) ])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("logprobs", [1, 6]) @pytest.mark.parametrize("logprobs", [1, 6])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 12]) @pytest.mark.parametrize("prefill_chunk_size", [-1, 4, 8])
def test_logprobs_equality(vllm_runner, common_llm_kwargs, def test_logprobs_equality(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, baseline_llm_kwargs, per_test_common_llm_kwargs, baseline_llm_kwargs,
test_llm_kwargs, batch_size: int, output_len: int, test_llm_kwargs, batch_size: int, output_len: int,
......
...@@ -27,6 +27,8 @@ from ..utils import maybe_enable_chunked_prefill ...@@ -27,6 +27,8 @@ from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
# main model # main model
# lmsys/vicuna-7b-v1.3 was to be used but it's causing # lmsys/vicuna-7b-v1.3 was to be used but it's causing
# OOM in CI pipeline, so using a smaller model. # OOM in CI pipeline, so using a smaller model.
...@@ -57,6 +59,9 @@ PRECISION = "float16" ...@@ -57,6 +59,9 @@ PRECISION = "float16"
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -71,9 +76,9 @@ PRECISION = "float16" ...@@ -71,9 +76,9 @@ PRECISION = "float16"
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
]) ])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) @pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
per_test_common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, baseline_llm_kwargs, test_llm_kwargs,
...@@ -106,6 +111,9 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -106,6 +111,9 @@ def test_medusa_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -169,6 +177,9 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -169,6 +177,9 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -183,9 +194,9 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -183,9 +194,9 @@ def test_medusa_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
128, 128,
]) ])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 8])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 32]) @pytest.mark.parametrize("prefill_chunk_size", [-1, 8])
def test_medusa_e2e_greedy_correctness_cuda_graph( def test_medusa_e2e_greedy_correctness_cuda_graph(
vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs,
baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int, baseline_llm_kwargs, test_llm_kwargs, batch_size: int, output_len: int,
...@@ -207,7 +218,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( ...@@ -207,7 +218,7 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"block_size": 8, "block_size": 16,
# 2 for small prompt, 256//8 for generated. # 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8, "num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8, "max_model_len": (2 + 256 // 8) * 8,
...@@ -220,6 +231,9 @@ def test_medusa_e2e_greedy_correctness_cuda_graph( ...@@ -220,6 +231,9 @@ def test_medusa_e2e_greedy_correctness_cuda_graph(
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -270,6 +284,9 @@ def test_medusa_e2e_greedy_correctness_with_preemption( ...@@ -270,6 +284,9 @@ def test_medusa_e2e_greedy_correctness_with_preemption(
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -324,6 +341,9 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs, ...@@ -324,6 +341,9 @@ def test_medusa_different_k(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -375,6 +395,9 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -375,6 +395,9 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs,
# Main model # Main model
"model_name": MAIN_MODEL, "model_name": MAIN_MODEL,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
...@@ -32,6 +32,7 @@ from ..utils import maybe_enable_chunked_prefill ...@@ -32,6 +32,7 @@ from ..utils import maybe_enable_chunked_prefill
from .conftest import run_equality_correctness_test from .conftest import run_equality_correctness_test
from ...utils import models_path_prefix from ...utils import models_path_prefix
os.environ["LLAMA_NN"] = "0"
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
...@@ -41,6 +42,9 @@ from ...utils import models_path_prefix ...@@ -41,6 +42,9 @@ from ...utils import models_path_prefix
# Print spec metrics. # Print spec metrics.
"disable_log_stats": False, "disable_log_stats": False,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
...@@ -69,7 +73,7 @@ from ...utils import models_path_prefix ...@@ -69,7 +73,7 @@ from ...utils import models_path_prefix
@pytest.mark.parametrize("output_len", [ @pytest.mark.parametrize("output_len", [
256, 256,
]) ])
@pytest.mark.parametrize("batch_size", [1, 32]) @pytest.mark.parametrize("batch_size", [1, 4])
@pytest.mark.parametrize("prefill_chunk_size", [-1, 4]) @pytest.mark.parametrize("prefill_chunk_size", [-1, 4])
@pytest.mark.parametrize("seed", [1]) @pytest.mark.parametrize("seed", [1])
def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
...@@ -98,6 +102,9 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs, ...@@ -98,6 +102,9 @@ def test_ngram_e2e_greedy_correctness(vllm_runner, common_llm_kwargs,
# Print spec metrics. # Print spec metrics.
"disable_log_stats": False, "disable_log_stats": False,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
...@@ -154,13 +161,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs, ...@@ -154,13 +161,16 @@ def test_ngram_e2e_greedy_logprobs(vllm_runner, common_llm_kwargs,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"common_llm_kwargs", "common_llm_kwargs",
[{ [{
"block_size": 8, "block_size": 16,
# 2 for small prompt, 256//8 for generated. # 2 for small prompt, 256//8 for generated.
"num_gpu_blocks_override": 2 + 256 // 8, "num_gpu_blocks_override": 2 + 256 // 8,
"max_model_len": (2 + 256 // 8) * 8, "max_model_len": (2 + 256 // 8) * 8,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [ @pytest.mark.parametrize("per_test_common_llm_kwargs", [
{ {
...@@ -222,6 +232,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption( ...@@ -222,6 +232,9 @@ def test_ngram_e2e_greedy_correctness_with_preemption(
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -282,6 +295,9 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs, ...@@ -282,6 +295,9 @@ def test_ngram_different_k(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
...@@ -338,6 +354,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, ...@@ -338,6 +354,9 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs,
# Skip cuda graph recording for fast test. # Skip cuda graph recording for fast test.
"enforce_eager": True, "enforce_eager": True,
# GPU memory utilization
"gpu_memory_utilization": 0.6
}]) }])
@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}])
@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) @pytest.mark.parametrize("baseline_llm_kwargs", [{}])
......
...@@ -23,6 +23,7 @@ def test_initial_call_returns_none(): ...@@ -23,6 +23,7 @@ def test_initial_call_returns_none():
collector = AsyncMetricsCollector(spec_decode_sampler) collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5) maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None assert maybe_metrics is None
...@@ -49,6 +50,7 @@ def test_second_call_returns_metrics(): ...@@ -49,6 +50,7 @@ def test_second_call_returns_metrics():
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5) _ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None assert metrics is not None
...@@ -69,6 +71,7 @@ def test_nonzero_rank_noop(rank): ...@@ -69,6 +71,7 @@ def test_nonzero_rank_noop(rank):
collector = AsyncMetricsCollector(spec_decode_sampler) collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=rank) collector.init_gpu_tensors(rank=rank)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5) _ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None assert metrics is None
...@@ -97,6 +100,7 @@ def test_noop_until_time(): ...@@ -97,6 +100,7 @@ def test_noop_until_time():
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5) _ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5)
...@@ -136,6 +140,7 @@ def test_timer_is_reset(): ...@@ -136,6 +140,7 @@ def test_timer_is_reset():
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5) _ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5) metrics = collector.maybe_collect_rejsample_metrics(k=5)
...@@ -186,6 +191,7 @@ def test_initial_metrics_has_correct_values(has_data: bool): ...@@ -186,6 +191,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
timer=timer, timer=timer,
collect_interval_s=collect_interval_s) collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0) collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k) _ = collector.maybe_collect_rejsample_metrics(k)
metrics = collector.maybe_collect_rejsample_metrics(k) metrics = collector.maybe_collect_rejsample_metrics(k)
......
...@@ -6,6 +6,7 @@ import os ...@@ -6,6 +6,7 @@ import os
import pytest import pytest
import torch import torch
from vllm.attention.selector import get_attn_backend
from vllm.sequence import ExecuteModelRequest from vllm.sequence import ExecuteModelRequest
from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer
from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores
...@@ -63,6 +64,11 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int, ...@@ -63,6 +64,11 @@ def test_scorer(model_name: str, batch_size: int, max_propose_len: int,
num_gpu_blocks = 2048 // block_size num_gpu_blocks = 2048 // block_size
scorer_worker = create_worker(Worker, model_name, block_size, scorer_worker = create_worker(Worker, model_name, block_size,
num_gpu_blocks, seed) num_gpu_blocks, seed)
head_size = scorer_worker.model_config.get_head_size()
backend = get_attn_backend(head_size, torch.float16, torch.float16, 16, False)
if backend.get_name() != "FLASH_ATTN":
pytest.skip("MQAScorer is only available with flash attn backend.")
scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer scorer_worker.model_runner.disable_logprobs = True # accessed by mqa_scorer
scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True scorer_worker.model_runner.sampler.include_gpu_probs_tensor = True
scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True scorer_worker.model_runner.sampler.should_modify_greedy_probs_inplace = True
......
...@@ -126,6 +126,17 @@ else: ...@@ -126,6 +126,17 @@ else:
{"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4} {"BLOCK_SIZE_M": 32,"BLOCK_SIZE_N": 64,"BLOCK_SIZE_K": 64,"GROUP_SIZE_M": 1,"kpack": 2,"num_stages": 0,"num_warps": 4}
] ]
@triton.jit
def write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N, offs_token,
token_mask, BLOCK_SIZE_M, BLOCK_SIZE_N,
compute_type):
accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
None, :]
c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
tl.store(c_ptrs, accumulator, mask=c_mask)
@triton.jit @triton.jit
def fused_moe_kernel_awq( def fused_moe_kernel_awq(
# Pointers to matrices # Pointers to matrices
...@@ -172,7 +183,8 @@ def fused_moe_kernel_awq( ...@@ -172,7 +183,8 @@ def fused_moe_kernel_awq(
compute_type: tl.constexpr, compute_type: tl.constexpr,
has_zp: tl.constexpr, has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr, use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr,
enable_expert_parallel: int,):
pid = tl.program_id(axis=0) pid = tl.program_id(axis=0)
num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M) num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
num_pid_n = tl.cdiv(N, BLOCK_SIZE_N) num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
...@@ -190,6 +202,17 @@ def fused_moe_kernel_awq( ...@@ -190,6 +202,17 @@ def fused_moe_kernel_awq(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m] offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) # [block_m]
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N + offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N)) % N # [block_n] tl.arange(0, BLOCK_SIZE_N)) % N # [block_n]
offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k] offs_k = tl.arange(0, BLOCK_SIZE_K) # 0, 1, 2, ...... , 127 # # [block_k]
...@@ -197,8 +220,6 @@ def fused_moe_kernel_awq( ...@@ -197,8 +220,6 @@ def fused_moe_kernel_awq(
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) # [block_m, block_k] offs_k[None, :] * stride_ak) # [block_m, block_k]
off_experts = tl.load(expert_ids_ptr + pid_m)
if use_int4_w4a16: if use_int4_w4a16:
# [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63] # [0, 1, 2, ...... , 126, 127] --> [0, 0, 1, 1 ...... , 63, 63]
# [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127] # [128, 129, 130, ...... , 254, 255] --> [64, 64, 65, 65 ...... , 127, 127]
...@@ -329,7 +350,8 @@ def fused_moe_kernel_gptq_awq( ...@@ -329,7 +350,8 @@ def fused_moe_kernel_gptq_awq(
compute_type: tl.constexpr, compute_type: tl.constexpr,
has_zp: tl.constexpr, has_zp: tl.constexpr,
use_int4_w4a16: tl.constexpr, use_int4_w4a16: tl.constexpr,
use_int8_w8a16: tl.constexpr): use_int8_w8a16: tl.constexpr,
enable_expert_parallel: int,):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
token and expert matrices. token and expert matrices.
...@@ -383,14 +405,23 @@ def fused_moe_kernel_gptq_awq( ...@@ -383,14 +405,23 @@ def fused_moe_kernel_gptq_awq(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N + offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
if use_int4_w4a16: if use_int4_w4a16:
b_ptrs = b_ptr + off_experts * stride_be + \ b_ptrs = b_ptr + off_experts * stride_be + \
(offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn (offs_k[:, None] // 2) * stride_bk + offs_bn[None, :] * stride_bn
...@@ -531,6 +562,7 @@ def fused_moe_kernel( ...@@ -531,6 +562,7 @@ def fused_moe_kernel(
use_int8_w8a8: tl.constexpr, use_int8_w8a8: tl.constexpr,
use_int8_w8a16: tl.constexpr, use_int8_w8a16: tl.constexpr,
per_channel_quant: tl.constexpr, per_channel_quant: tl.constexpr,
enable_expert_parallel: int,
): ):
""" """
Implements the fused computation for a Mixture of Experts (MOE) using Implements the fused computation for a Mixture of Experts (MOE) using
...@@ -598,13 +630,23 @@ def fused_moe_kernel( ...@@ -598,13 +630,23 @@ def fused_moe_kernel(
offs_token = tl.load(sorted_token_ids_ptr + offs_token_id) offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
token_mask = offs_token < num_valid_tokens token_mask = offs_token < num_valid_tokens
off_experts = tl.load(expert_ids_ptr + pid_m)
if enable_expert_parallel:
if off_experts == -1:
# -----------------------------------------------------------
# Write back zeros to the output when the expert is not
# in the current expert parallel rank.
write_zeros_to_output(c_ptr, stride_cm, stride_cn, pid_n, N,
offs_token, token_mask, BLOCK_SIZE_M,
BLOCK_SIZE_N, compute_type)
return
offs_bn = (pid_n * BLOCK_SIZE_N + offs_bn = (pid_n * BLOCK_SIZE_N +
tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
offs_k = tl.arange(0, BLOCK_SIZE_K) offs_k = tl.arange(0, BLOCK_SIZE_K)
a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am + a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
offs_k[None, :] * stride_ak) offs_k[None, :] * stride_ak)
off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk + b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
offs_bn[None, :] * stride_bn) offs_bn[None, :] * stride_bn)
if use_int8_w8a16: if use_int8_w8a16:
...@@ -719,7 +761,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -719,7 +761,8 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int4_w4a16: bool, use_int4_w4a16: bool,
per_channel_quant: bool, per_channel_quant: bool,
block_shape: Optional[List[int]] = None, block_shape: Optional[List[int]] = None,
use_nn_moe: Optional[bool]=False) -> None: use_nn_moe: Optional[bool]=False,
enable_expert_parallel: int=0,) -> None:
assert topk_weights is not None or not mul_routed_weight assert topk_weights is not None or not mul_routed_weight
assert topk_weights is None or topk_weights.stride(1) == 1 assert topk_weights is None or topk_weights.stride(1) == 1
assert sorted_token_ids.stride(0) == 1 assert sorted_token_ids.stride(0) == 1
...@@ -806,6 +849,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -806,6 +849,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp=B_zp is not None, has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
enable_expert_parallel=enable_expert_parallel,
**config, **config,
) )
else: else:
...@@ -844,6 +888,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -844,6 +888,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
has_zp=B_zp is not None, has_zp=B_zp is not None,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
enable_expert_parallel=enable_expert_parallel,
**config, **config,
) )
else: else:
...@@ -892,6 +937,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -892,6 +937,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a8=use_int8_w8a8, use_int8_w8a8=use_int8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
enable_expert_parallel=enable_expert_parallel,
# BLOCK_SIZE_K=BLOCK_SIZE_K, # BLOCK_SIZE_K=BLOCK_SIZE_K,
**config, **config,
) )
...@@ -1704,6 +1750,7 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1704,6 +1750,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'], moe_align_block_size(curr_topk_ids, config['BLOCK_SIZE_M'],
global_num_experts, expert_map)) global_num_experts, expert_map))
enable_expert_parallel = (int)(expert_map is not None)
invoke_fused_moe_kernel(qcurr_hidden_states, invoke_fused_moe_kernel(qcurr_hidden_states,
w1, w1,
intermediate_cache1, intermediate_cache1,
...@@ -1725,7 +1772,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1725,7 +1772,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
enable_expert_parallel=enable_expert_parallel)
if activation == "silu": if activation == "silu":
torch.ops._C.silu_and_mul(intermediate_cache2, torch.ops._C.silu_and_mul(intermediate_cache2,
...@@ -1786,7 +1834,8 @@ def fused_experts_impl(hidden_states: torch.Tensor, ...@@ -1786,7 +1834,8 @@ def fused_experts_impl(hidden_states: torch.Tensor,
use_int4_w4a16=use_int4_w4a16, use_int4_w4a16=use_int4_w4a16,
per_channel_quant=per_channel_quant, per_channel_quant=per_channel_quant,
block_shape=block_shape, block_shape=block_shape,
use_nn_moe=use_nn_moe) use_nn_moe=use_nn_moe,
enable_expert_parallel=enable_expert_parallel)
ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape), ops.moe_sum(intermediate_cache3.view(*intermediate_cache3.shape),
out_hidden_states[begin_chunk_idx:end_chunk_idx]) out_hidden_states[begin_chunk_idx:end_chunk_idx])
......
...@@ -216,6 +216,11 @@ def moe_align_block_size( ...@@ -216,6 +216,11 @@ def moe_align_block_size(
max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size) max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
# Expert ids must be zeroed out to prevent index out of bounds error while # Expert ids must be zeroed out to prevent index out of bounds error while
# mapping global expert ids to local expert ids in expert parallelism. # mapping global expert ids to local expert ids in expert parallelism.
if expert_map is not None:
expert_ids = torch.zeros((max_num_m_blocks, ),
dtype=torch.int32,
device=topk_ids.device)
else:
expert_ids = torch.empty((max_num_m_blocks, ), expert_ids = torch.empty((max_num_m_blocks, ),
dtype=torch.int32, dtype=torch.int32,
device=topk_ids.device) device=topk_ids.device)
......
...@@ -28,11 +28,9 @@ class ResidualBlock(nn.Module): ...@@ -28,11 +28,9 @@ class ResidualBlock(nn.Module):
super().__init__() super().__init__()
self.layers = nn.ModuleList([ self.layers = nn.ModuleList([
# nn.Linear(hidden_size,
# hidden_size,
# bias=getattr(config, "medusa_fc_bias", False))
nn.Linear(hidden_size, nn.Linear(hidden_size,
hidden_size) hidden_size,
bias=getattr(config, "medusa_fc_bias", False))
for _ in range(num_layers) for _ in range(num_layers)
]) ])
self.act = nn.SiLU() self.act = nn.SiLU()
...@@ -117,7 +115,7 @@ class Medusa(nn.Module): ...@@ -117,7 +115,7 @@ class Medusa(nn.Module):
for hs, lm_head in zip(hidden_states, self.lm_heads): for hs, lm_head in zip(hidden_states, self.lm_heads):
#_logits = self.logits_processor(lm_head, hs, sampling_metadata) #_logits = self.logits_processor(lm_head, hs, sampling_metadata)
_logits = lm_head.linear_method.apply(lm_head, hs, bias=None) _logits = lm_head.quant_method.apply(lm_head, hs, bias=None)
_logits = tensor_model_parallel_all_gather(_logits) _logits = tensor_model_parallel_all_gather(_logits)
if _logits is None: if _logits is None:
......
...@@ -29,9 +29,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase): ...@@ -29,9 +29,7 @@ class MedusaWorker(NonLLMProposerWorkerBase, DelegateWorkerBase):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
# skip lora config in medusa # skip lora config in medusa
kwargs_copy = kwargs.copy() DelegateWorkerBase.__init__(self, *args, **kwargs)
kwargs_copy['lora_config'] = None
DelegateWorkerBase.__init__(self, *args, **kwargs_copy)
# Lazy initialization list. # Lazy initialization list.
self._proposer: SpeculativeProposer self._proposer: SpeculativeProposer
self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1') self.tree_decoding = (os.environ.get('VLLM_TREE_DECODING') == '1')
......
...@@ -175,8 +175,6 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase): ...@@ -175,8 +175,6 @@ class SpecDecodeWorker(LoRANotSupportedWorkerBase):
draft_parallel_config: ParallelConfig = draft_worker_kwargs[ draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'vllm_config'].parallel_config 'vllm_config'].parallel_config
if ngram_prompt_lookup_max > 0: if ngram_prompt_lookup_max > 0:
draft_parallel_config: ParallelConfig = draft_worker_kwargs[
'parallel_config']
assert draft_parallel_config.tensor_parallel_size == 1 assert draft_parallel_config.tensor_parallel_size == 1
draft_worker_kwargs[ draft_worker_kwargs[
"device_type"] = scorer_worker.device_config.device.type "device_type"] = scorer_worker.device_config.device.type
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment