"docs/vscode:/vscode.git/clone" did not exist on "8dbe0c527fa76cd908bb6287f9f501df44e04473"
Commit ca796e19 authored by zhuwenwen
Browse files

Merge tag 'v0.8.1' into v0.8.1-ori

parents e983c804 61c7a1b8
...@@ -76,21 +76,18 @@ async def generate(engine: AsyncLLM, ...@@ -76,21 +76,18 @@ async def generate(engine: AsyncLLM,
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args_and_prompt", @pytest.mark.parametrize("engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)]) (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_load( async def test_load(monkeypatch: pytest.MonkeyPatch,
monkeypatch: pytest.MonkeyPatch, output_kind: RequestOutputKind,
output_kind: RequestOutputKind, engine_args: AsyncEngineArgs, prompt: PromptType):
engine_args_and_prompt: tuple[AsyncEngineArgs, PromptType],
):
# TODO(rickyx): Remove monkeypatch once we have a better way to test V1 # TODO(rickyx): Remove monkeypatch once we have a better way to test V1
# so that in the future when we switch, we don't have to change all the # so that in the future when we switch, we don't have to change all the
# tests. # tests.
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine_args, prompt = engine_args_and_prompt
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
...@@ -124,18 +121,16 @@ async def test_load( ...@@ -124,18 +121,16 @@ async def test_load(
@pytest.mark.parametrize( @pytest.mark.parametrize(
"output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY]) "output_kind", [RequestOutputKind.DELTA, RequestOutputKind.FINAL_ONLY])
@pytest.mark.parametrize("engine_args_and_prompt", @pytest.mark.parametrize("engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)]) (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_abort(monkeypatch: pytest.MonkeyPatch, async def test_abort(monkeypatch: pytest.MonkeyPatch,
output_kind: RequestOutputKind, output_kind: RequestOutputKind,
engine_args_and_prompt: tuple[AsyncEngineArgs, engine_args: AsyncEngineArgs, prompt: PromptType):
PromptType]):
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine_args, prompt = engine_args_and_prompt
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
...@@ -193,17 +188,15 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch, ...@@ -193,17 +188,15 @@ async def test_abort(monkeypatch: pytest.MonkeyPatch,
@pytest.mark.parametrize("n", [1, 3]) @pytest.mark.parametrize("n", [1, 3])
@pytest.mark.parametrize("engine_args_and_prompt", @pytest.mark.parametrize("engine_args,prompt",
[(TEXT_ENGINE_ARGS, TEXT_PROMPT), [(TEXT_ENGINE_ARGS, TEXT_PROMPT),
(VISION_ENGINE_ARGS, VISION_PROMPT)]) (VISION_ENGINE_ARGS, VISION_PROMPT)])
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_finished_flag(monkeypatch, n: int, async def test_finished_flag(monkeypatch: pytest.MonkeyPatch, n: int,
engine_args_and_prompt: tuple[AsyncEngineArgs, engine_args: AsyncEngineArgs, prompt: PromptType):
PromptType]):
with monkeypatch.context() as m, ExitStack() as after: with monkeypatch.context() as m, ExitStack() as after:
m.setenv("VLLM_USE_V1", "1") m.setenv("VLLM_USE_V1", "1")
engine_args, prompt = engine_args_and_prompt
engine = AsyncLLM.from_engine_args(engine_args) engine = AsyncLLM.from_engine_args(engine_args)
after.callback(engine.shutdown) after.callback(engine.shutdown)
......
...@@ -50,7 +50,7 @@ def _get_test_sampling_params( ...@@ -50,7 +50,7 @@ def _get_test_sampling_params(
"""Generate random sampling params for a batch.""" """Generate random sampling params for a batch."""
def get_mostly_n_gt1() -> int: def get_mostly_n_gt1() -> int:
"""Mostly n \in [2,20], ~1/3 n=1""" r"""Mostly n \in [2,20], ~1/3 n=1"""
x = random.randint(0, 28) x = random.randint(0, 28)
if x < 10: if x < 10:
return 1 return 1
......
...@@ -6,20 +6,23 @@ import torch ...@@ -6,20 +6,23 @@ import torch
import torch.nn.functional as F import torch.nn.functional as F
from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.rejection_sampler import INVALID_TOKEN_ID, RejectionSampler from vllm.v1.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
RejectionSampler)
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
DEVICE = "cpu" DEVICE = "cuda"
@pytest.fixture @pytest.fixture
def sampler(): def rejection_sampler():
return RejectionSampler() return RejectionSampler()
def create_logits_tensor(token_ids: list[list[int]], def create_logits_tensor(output_token_ids: list[list[int]],
vocab_size: int = 100) -> torch.Tensor: vocab_size: int = 100) -> torch.Tensor:
"""Helper function to create logits tensor that """Helper function to create logits tensor that
will produce desired token ids on argmax""" will produce desired token ids on argmax"""
token_ids = [tokens[:-1] for tokens in output_token_ids]
num_total_tokens = sum(len(tokens) for tokens in token_ids) num_total_tokens = sum(len(tokens) for tokens in token_ids)
logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE) logits = torch.full((num_total_tokens, vocab_size), -100.0, device=DEVICE)
start_loc = 0 start_loc = 0
...@@ -31,15 +34,22 @@ def create_logits_tensor(token_ids: list[list[int]], ...@@ -31,15 +34,22 @@ def create_logits_tensor(token_ids: list[list[int]],
def create_sampling_metadata( def create_sampling_metadata(
all_greedy: bool, all_greedy: bool,
generators: Optional[dict[int, Any]] = None) -> SamplingMetadata: temperature: Optional[torch.Tensor] = None,
generators: Optional[dict[int, Any]] = None,
) -> SamplingMetadata:
"""Create a v1 sampling metadata object with all_greedy set """Create a v1 sampling metadata object with all_greedy set
to the given value. Either all greedy or all random sampling to the given value. Either all greedy or all random sampling
is used. is used.
""" """
generators = generators or {} generators = generators or {}
if all_greedy:
temperature = None
else:
assert temperature is not None
return SamplingMetadata( return SamplingMetadata(
temperature=torch.tensor([]), temperature=temperature,
all_greedy=all_greedy, all_greedy=all_greedy,
all_random=not all_greedy, all_random=not all_greedy,
top_p=None, top_p=None,
...@@ -61,7 +71,7 @@ def create_sampling_metadata( ...@@ -61,7 +71,7 @@ def create_sampling_metadata(
########################### Tests for Greedy Sampling ################### ########################### Tests for Greedy Sampling ###################
def test_perfect_match(sampler): def test_perfect_match(rejection_sampler):
"""Test when output tokens perfectly match speculated tokens""" """Test when output tokens perfectly match speculated tokens"""
spec_tokens = [[1, 2, 3]] spec_tokens = [[1, 2, 3]]
output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token output_tokens = [[1, 2, 3, 4]] # 4 is the bonus token
...@@ -70,15 +80,23 @@ def test_perfect_match(sampler): ...@@ -70,15 +80,23 @@ def test_perfect_match(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 3, 4]], expected = torch.tensor([[1, 2, 3, 4]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_early_mismatch(sampler): def test_early_mismatch(rejection_sampler):
"""Test when there's an early mismatch in tokens""" """Test when there's an early mismatch in tokens"""
spec_tokens = [[1, 2, 3]] spec_tokens = [[1, 2, 3]]
output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1 output_tokens = [[1, 5, 3, 4]] # Mismatch at position 1
...@@ -87,15 +105,25 @@ def test_early_mismatch(sampler): ...@@ -87,15 +105,25 @@ def test_early_mismatch(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) device=logits.device)
expected = torch.tensor([[1, 5, INVALID_TOKEN_ID, INVALID_TOKEN_ID]],
dtype=torch.int, output = rejection_sampler(
device=logits.device) spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 5, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_multiple_sequences(sampler): def test_multiple_sequences(rejection_sampler):
"""Test handling multiple sequences of speculated tokens""" """Test handling multiple sequences of speculated tokens"""
spec_tokens = [[1, 2], [3]] spec_tokens = [[1, 2], [3]]
output_tokens = [[1, 2, 5], [3, output_tokens = [[1, 2, 5], [3,
...@@ -105,15 +133,23 @@ def test_multiple_sequences(sampler): ...@@ -105,15 +133,23 @@ def test_multiple_sequences(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) device=logits.device)
expected = torch.tensor([[1, 2, 5], [3, 4, INVALID_TOKEN_ID]],
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2, 5], [3, 4, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_single_token_sequence(sampler): def test_single_token_sequence(rejection_sampler):
"""Test handling sequences with single token""" """Test handling sequences with single token"""
spec_tokens = [[1]] spec_tokens = [[1]]
output_tokens = [[1, 2]] # Single token with bonus token 2 output_tokens = [[1, 2]] # Single token with bonus token 2
...@@ -122,13 +158,21 @@ def test_single_token_sequence(sampler): ...@@ -122,13 +158,21 @@ def test_single_token_sequence(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device) expected = torch.tensor([[1, 2]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_empty_sequence(sampler): def test_empty_sequence(rejection_sampler):
"""Test handling empty sequence of speculated tokens""" """Test handling empty sequence of speculated tokens"""
spec_tokens: list[list[int]] = [[]] spec_tokens: list[list[int]] = [[]]
output_tokens = [[5]] # Just the bonus token output_tokens = [[5]] # Just the bonus token
...@@ -137,13 +181,21 @@ def test_empty_sequence(sampler): ...@@ -137,13 +181,21 @@ def test_empty_sequence(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([output_tokens[0][-1]], bonus_token_tensor = torch.tensor([output_tokens[0][-1]],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor([[5]], dtype=torch.int, device=logits.device) expected = torch.tensor([[5]], dtype=torch.int, device=logits.device)
assert torch.equal(output, expected) assert torch.equal(output, expected)
def test_multiple_mismatches(sampler): def test_multiple_mismatches(rejection_sampler):
"""Test handling multiple sequences with mismatches""" """Test handling multiple sequences with mismatches"""
spec_tokens = [[1, 2, 3], [4, 5, 6]] spec_tokens = [[1, 2, 3], [4, 5, 6]]
output_tokens = [[1, 2, 7, 6], [4, 8, 6, output_tokens = [[1, 2, 7, 6], [4, 8, 6,
...@@ -153,12 +205,22 @@ def test_multiple_mismatches(sampler): ...@@ -153,12 +205,22 @@ def test_multiple_mismatches(sampler):
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor( bonus_token_tensor = torch.tensor(
[output_tokens[0][-1], output_tokens[1][-1]], device=logits.device) [output_tokens[0][-1], output_tokens[1][-1]], device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) device=logits.device)
expected = torch.tensor([[1, 2, 7, INVALID_TOKEN_ID],
[4, 8, INVALID_TOKEN_ID, INVALID_TOKEN_ID]], output = rejection_sampler(
dtype=torch.int, spec_decode_metadata,
device=logits.device) draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected = torch.tensor(
[[1, 2, 7, PLACEHOLDER_TOKEN_ID],
[4, 8, PLACEHOLDER_TOKEN_ID, PLACEHOLDER_TOKEN_ID]],
dtype=torch.int,
device=logits.device,
)
assert torch.equal(output, expected) assert torch.equal(output, expected)
...@@ -166,18 +228,27 @@ def test_multiple_mismatches(sampler): ...@@ -166,18 +228,27 @@ def test_multiple_mismatches(sampler):
"spec_tokens,output_tokens,expected", "spec_tokens,output_tokens,expected",
[ [
([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus ([[1, 2]], [[1, 2, 3]], [[1, 2, 3]]), # Perfect match with bonus
([[1]], [[2, 3]], [[2, INVALID_TOKEN_ID]]), # First mismatch ([[1]], [[2, 3]], [[2, PLACEHOLDER_TOKEN_ID]]), # First mismatch
([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]], ([[1, 2], [3, 4]], [[1, 5, 6], [3, 4, 7]],
[[1, 5, INVALID_TOKEN_ID], [3, 4, 7]]), # Mixed matches [[1, 5, PLACEHOLDER_TOKEN_ID], [3, 4, 7]]), # Mixed matches
]) ])
def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): def test_parametrized_cases(rejection_sampler, spec_tokens, output_tokens,
expected):
"""Parametrized test for various matching scenarios""" """Parametrized test for various matching scenarios"""
metadata = create_sampling_metadata(all_greedy=True) metadata = create_sampling_metadata(all_greedy=True)
logits = create_logits_tensor(output_tokens) logits = create_logits_tensor(output_tokens)
bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens], bonus_token_tensor = torch.tensor([tokens[-1] for tokens in output_tokens],
device=logits.device) device=logits.device)
spec_decode_metadata = SpecDecodeMetadata.make_dummy(spec_tokens,
output = sampler(spec_tokens, None, bonus_token_tensor, logits, metadata) device=logits.device)
output = rejection_sampler(
spec_decode_metadata,
draft_probs=None,
target_logits=logits,
bonus_token_ids=bonus_token_tensor,
sampling_metadata=metadata,
)
expected_tensor = torch.tensor(expected, expected_tensor = torch.tensor(expected,
dtype=torch.int, dtype=torch.int,
device=logits.device) device=logits.device)
...@@ -190,21 +261,31 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected): ...@@ -190,21 +261,31 @@ def test_parametrized_cases(sampler, spec_tokens, output_tokens, expected):
@pytest.mark.parametrize("batch_size", [1, 4, 8]) @pytest.mark.parametrize("batch_size", [1, 4, 8])
@pytest.mark.parametrize("frac_seeded", [0.0, 0.5]) @pytest.mark.parametrize("frac_seeded", [0.0, 0.5])
@pytest.mark.parametrize("n_rep", [20]) @pytest.mark.parametrize("n_rep", [20])
def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, def test_deterministic_when_seeded(
batch_size: int, frac_seeded: float, rejection_sampler,
n_rep: int): k: int,
draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32) vocab_size: int,
target_probs = torch.rand(batch_size * (k + 1), batch_size: int,
vocab_size, frac_seeded: float,
dtype=torch.float32) n_rep: int,
):
num_tokens = batch_size * k
draft_probs = torch.rand(num_tokens,
vocab_size,
dtype=torch.float32,
device=DEVICE)
draft_probs = F.softmax(draft_probs, dim=-1)
target_logits = torch.rand_like(draft_probs)
bonus_token_ids = torch.randint(low=0, bonus_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, 1), size=(batch_size, 1),
dtype=torch.int64) dtype=torch.int64,
device=DEVICE)
draft_token_ids = torch.randint(low=0, draft_token_ids = torch.randint(low=0,
high=vocab_size, high=vocab_size,
size=(batch_size, k), size=(batch_size, k),
dtype=torch.int64) dtype=torch.int64,
device=DEVICE)
seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded seeded_mask = torch.rand(batch_size, dtype=torch.float32) <= frac_seeded
...@@ -215,10 +296,21 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int, ...@@ -215,10 +296,21 @@ def test_deterministic_when_seeded(sampler, k: int, vocab_size: int,
for i in range(batch_size) if seeded_mask[i] for i in range(batch_size) if seeded_mask[i]
} }
temperature = torch.ones(batch_size,
dtype=torch.float32,
device=DEVICE)
sampling_metadata = create_sampling_metadata(all_greedy=False, sampling_metadata = create_sampling_metadata(all_greedy=False,
temperature=temperature,
generators=seeded_seqs) generators=seeded_seqs)
rep_result = sampler(draft_token_ids.tolist(), draft_probs, spec_decode_metadata = SpecDecodeMetadata.make_dummy(
bonus_token_ids, target_probs, sampling_metadata) draft_token_ids.tolist(), device=DEVICE)
rep_result = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)
results.append(rep_result) results.append(rep_result)
...@@ -257,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution(): ...@@ -257,10 +349,10 @@ def test_rejection_sampling_approximates_target_distribution():
num_reference_probs = 100 num_reference_probs = 100
# Prepare draft, target, and reference probability distributions # Prepare draft, target, and reference probability distributions
draft_probs, target_probs = (F.softmax( draft_probs = F.softmax(torch.rand(vocab_size, dtype=torch.float32),
torch.rand(vocab_size, dtype=torch.float32), dim=-1)
dim=-1, target_logits = torch.rand(vocab_size, dtype=torch.float32)
) for _ in range(2)) target_probs = F.softmax(target_logits, dim=-1)
reference_probs = F.softmax( reference_probs = F.softmax(
torch.rand(num_reference_probs, vocab_size, dtype=torch.float32), torch.rand(num_reference_probs, vocab_size, dtype=torch.float32),
dim=-1, dim=-1,
...@@ -273,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution(): ...@@ -273,7 +365,7 @@ def test_rejection_sampling_approximates_target_distribution():
for num_samples in sample_sizes: for num_samples in sample_sizes:
# Sample using rejection sampling. # Sample using rejection sampling.
rej_sample_probs = estimate_rejection_sampling_pdf( rej_sample_probs = estimate_rejection_sampling_pdf(
draft_probs, target_probs, k, vocab_size, num_samples) draft_probs, target_logits, k, vocab_size, num_samples)
rej_sample_probs = rej_sample_probs.to(DEVICE) rej_sample_probs = rej_sample_probs.to(DEVICE)
# Average distance from reference probs. # Average distance from reference probs.
...@@ -313,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float: ...@@ -313,7 +405,7 @@ def get_ratio_first_to_last(elements: list[float]) -> float:
def estimate_rejection_sampling_pdf( def estimate_rejection_sampling_pdf(
draft_probs: torch.Tensor, draft_probs: torch.Tensor,
target_probs: torch.Tensor, target_logits: torch.Tensor,
k: int, k: int,
vocab_size: int, vocab_size: int,
num_samples: int, num_samples: int,
...@@ -323,35 +415,44 @@ def estimate_rejection_sampling_pdf( ...@@ -323,35 +415,44 @@ def estimate_rejection_sampling_pdf(
Args: Args:
draft_probs: Draft probability distribution. draft_probs: Draft probability distribution.
target_probs: Target probability distribution. target_logits: Target logits.
num_samples: Number of samples to draw. num_samples: Number of samples to draw.
Returns: Returns:
Estimated probability distribution of the output tokens. Estimated probability distribution of the output tokens.
""" """
sampler = RejectionSampler() rejection_sampler = RejectionSampler()
# Repeat draft probs num_samples times. num_tokens = num_samples * k
# Repeat draft probs num_samples * k times.
draft_probs = draft_probs.reshape(1, 1, draft_probs = draft_probs.reshape(1, 1,
vocab_size).repeat(num_samples, k, 1) vocab_size).repeat(num_samples, k, 1)
# Repeat target probs num_samples * (k + 1) times. # Repeat target probs num_tokens times.
target_probs = target_probs.reshape(1, 1, vocab_size).repeat( target_logits = target_logits.reshape(1, vocab_size).repeat(num_tokens, 1)
num_samples, k + 1, 1).reshape(num_samples * (k + 1), vocab_size)
# Randomly sample draft token ids from draft probs. # Randomly sample draft token ids from draft probs.
draft_token_ids = torch.multinomial(draft_probs[:, 0, :], draft_token_ids = torch.multinomial(draft_probs[:, 0, :],
num_samples=k, num_samples=k,
replacement=True).reshape( replacement=True).reshape(
num_samples, k) num_samples, k)
draft_probs = draft_probs.view(num_tokens, vocab_size)
# Bonus tokens not used but required. # Bonus tokens not used but required.
bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64, bonus_token_ids = torch.zeros((1, 1), dtype=torch.int64,
device=DEVICE).repeat(num_samples, 1) device=DEVICE).repeat(num_samples, 1)
sampling_metadata = create_sampling_metadata(all_greedy=False) temperature = torch.ones(num_samples, dtype=torch.float32, device=DEVICE)
output_token_ids = sampler(draft_token_ids.tolist(), draft_probs, sampling_metadata = create_sampling_metadata(all_greedy=False,
bonus_token_ids, target_probs, temperature=temperature)
sampling_metadata) spec_decode_metadata = SpecDecodeMetadata.make_dummy(
draft_token_ids.tolist(), device=bonus_token_ids.device)
output_token_ids = rejection_sampler(
spec_decode_metadata,
draft_probs=draft_probs,
target_logits=target_logits,
bonus_token_ids=bonus_token_ids,
sampling_metadata=sampling_metadata,
)
output_token_ids = output_token_ids[:, :-1].flatten() output_token_ids = output_token_ids[:, :-1].flatten()
hist = torch.histogram(output_token_ids.to(dtype=torch.float, hist = torch.histogram(output_token_ids.to(dtype=torch.float,
......
...@@ -15,9 +15,10 @@ if TYPE_CHECKING: ...@@ -15,9 +15,10 @@ if TYPE_CHECKING:
from tests.conftest import VllmRunner from tests.conftest import VllmRunner
MODELS = [ MODELS = [
"Qwen/Qwen2.5-1.5B-Instruct",
# TODO: Enable these models with v6e
# "Qwen/Qwen2-7B-Instruct", # "Qwen/Qwen2-7B-Instruct",
"meta-llama/Llama-3.1-8B", # "meta-llama/Llama-3.1-8B",
# TODO: Add models here as necessary
] ]
TENSOR_PARALLEL_SIZES = [1] TENSOR_PARALLEL_SIZES = [1]
......
...@@ -347,7 +347,7 @@ class ModelConfig: ...@@ -347,7 +347,7 @@ class ModelConfig:
self.encoder_config = self._get_encoder_config() self.encoder_config = self._get_encoder_config()
self.hf_image_processor_config = get_hf_image_processor_config( self.hf_image_processor_config = get_hf_image_processor_config(
self.model, revision) self.model, revision)
self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype) self.dtype = _get_and_verify_dtype(self.hf_config, dtype)
self.use_async_output_proc = use_async_output_proc self.use_async_output_proc = use_async_output_proc
self.mm_processor_kwargs = mm_processor_kwargs self.mm_processor_kwargs = mm_processor_kwargs
self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache self.disable_mm_preprocessor_cache = disable_mm_preprocessor_cache
...@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype( ...@@ -2526,6 +2526,14 @@ def _get_and_verify_dtype(
# NOTE: getattr(config, "torch_dtype", torch.float32) is not correct # NOTE: getattr(config, "torch_dtype", torch.float32) is not correct
# because config.torch_dtype can be None. # because config.torch_dtype can be None.
config_dtype = getattr(config, "torch_dtype", None) config_dtype = getattr(config, "torch_dtype", None)
# Fallbacks for multi-modal models if the root config
# does not define torch_dtype
if config_dtype is None and hasattr(config, "text_config"):
config_dtype = getattr(config.text_config, "torch_dtype", None)
if config_dtype is None and hasattr(config, "vision_config"):
config_dtype = getattr(config.vision_config, "torch_dtype", None)
if config_dtype is None: if config_dtype is None:
config_dtype = torch.float32 config_dtype = torch.float32
...@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype( ...@@ -2533,16 +2541,8 @@ def _get_and_verify_dtype(
dtype = dtype.lower() dtype = dtype.lower()
if dtype == "auto": if dtype == "auto":
if config_dtype == torch.float32: if config_dtype == torch.float32:
if config.model_type in ("gemma2", "gemma3", "gemma3_text"): # Following common practice, we use float16 for float32 models
logger.info( torch_dtype = torch.float16
"For Gemma 2 and 3, we downcast float32 to bfloat16 "
"instead of float16 by default. Please specify `dtype` "
"if you want to use float16.")
torch_dtype = torch.bfloat16
else:
# Following the common practice, we use float16 for float32
# models.
torch_dtype = torch.float16
else: else:
torch_dtype = config_dtype torch_dtype = config_dtype
......
...@@ -1469,8 +1469,12 @@ class EngineArgs: ...@@ -1469,8 +1469,12 @@ class EngineArgs:
return False return False
# Need at least Ampere for now (FA support required). # Need at least Ampere for now (FA support required).
# Skip this check if we are running on a non-GPU platform,
# or if the device capability is not available
# (e.g. in a Ray actor without GPUs).
from vllm.platforms import current_platform from vllm.platforms import current_platform
if (current_platform.is_cuda() if (current_platform.is_cuda()
and current_platform.get_device_capability()
and current_platform.get_device_capability().major < 8): and current_platform.get_device_capability().major < 8):
_raise_or_fallback(feature_name="Compute Capability < 8.0", _raise_or_fallback(feature_name="Compute Capability < 8.0",
recommend_to_remove=False) recommend_to_remove=False)
...@@ -1574,6 +1578,13 @@ class EngineArgs: ...@@ -1574,6 +1578,13 @@ class EngineArgs:
_raise_or_fallback(feature_name=name, recommend_to_remove=True) _raise_or_fallback(feature_name=name, recommend_to_remove=True)
return False return False
# No support for device type other than CUDA, AMD (experimental) or
# TPU (experimental) so far.
if not (current_platform.is_cuda_alike() or current_platform.is_tpu()):
_raise_or_fallback(
feature_name=f"device type={current_platform.device_type}",
recommend_to_remove=False)
return False
############################################################# #############################################################
# Experimental Features - allow users to opt in. # Experimental Features - allow users to opt in.
......
...@@ -548,7 +548,7 @@ class ChatCompletionRequest(OpenAIBaseModel): ...@@ -548,7 +548,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
if top_logprobs < 0: if top_logprobs < 0:
raise ValueError("`top_logprobs` must be a positive value.") raise ValueError("`top_logprobs` must be a positive value.")
if not data.get("logprobs"): if top_logprobs > 0 and not data.get("logprobs"):
raise ValueError( raise ValueError(
"when using `top_logprobs`, `logprobs` must be set to true." "when using `top_logprobs`, `logprobs` must be set to true."
) )
......
...@@ -35,7 +35,6 @@ if TYPE_CHECKING: ...@@ -35,7 +35,6 @@ if TYPE_CHECKING:
VLLM_TRACE_FUNCTION: int = 0 VLLM_TRACE_FUNCTION: int = 0
VLLM_ATTENTION_BACKEND: Optional[str] = None VLLM_ATTENTION_BACKEND: Optional[str] = None
VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None VLLM_USE_FLASHINFER_SAMPLER: Optional[bool] = None
VLLM_USE_FLASHINFER_REJECTION_SAMPLER: bool = False
VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False VLLM_FLASHINFER_FORCE_TENSOR_CORES: bool = False
VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_PP_LAYER_PARTITION: Optional[str] = None
VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_KVCACHE_SPACE: int = 0
......
...@@ -16,12 +16,8 @@ import torch ...@@ -16,12 +16,8 @@ import torch
from vllm.config import VllmConfig from vllm.config import VllmConfig
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.triton_utils.importing import HAS_TRITON
from vllm.utils import _check_multiproc_method, get_mp_context, run_method from vllm.utils import _check_multiproc_method, get_mp_context, run_method
if HAS_TRITON:
from vllm.triton_utils import maybe_set_triton_cache_manager
logger = init_logger(__name__) logger = init_logger(__name__)
T = TypeVar('T') T = TypeVar('T')
...@@ -314,7 +310,3 @@ def set_multiprocessing_worker_envs(parallel_config): ...@@ -314,7 +310,3 @@ def set_multiprocessing_worker_envs(parallel_config):
current_parallelism, default_omp_num_threads) current_parallelism, default_omp_num_threads)
os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads) os.environ["OMP_NUM_THREADS"] = str(default_omp_num_threads)
torch.set_num_threads(default_omp_num_threads) torch.set_num_threads(default_omp_num_threads)
# workaround for https://github.com/vllm-project/vllm/issues/6103
if HAS_TRITON and parallel_config.world_size > 1:
maybe_set_triton_cache_manager()
...@@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor, ...@@ -30,6 +30,7 @@ from vllm.lora.utils import (from_layer, from_layer_logits_processor,
is_regex_target_modules, is_regex_target_modules,
parse_fine_tuned_lora_name, replace_submodule) parse_fine_tuned_lora_name, replace_submodule)
from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models import SupportsLoRA, supports_multimodal
from vllm.model_executor.models.interfaces import is_pooling_model
from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.module_mapping import MultiModelKeys
from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper from vllm.model_executor.models.utils import PPMissingLayer, WeightsMapper
from vllm.utils import is_pin_memory_available from vllm.utils import is_pin_memory_available
...@@ -104,6 +105,9 @@ class LoRAModel(AdapterModel): ...@@ -104,6 +105,9 @@ class LoRAModel(AdapterModel):
"""Get LoRA for a given module by name""" """Get LoRA for a given module by name"""
return self.loras.get(module_name, None) return self.loras.get(module_name, None)
def check_lora_name(self, lora_name: str) -> bool:
    """Tell whether this adapter holds weights under ``lora_name``."""
    known_names = self.loras
    return lora_name in known_names
# (yard1): TODO see if we can derive target_embedding_padding automatically # (yard1): TODO see if we can derive target_embedding_padding automatically
@classmethod @classmethod
def from_lora_tensors( def from_lora_tensors(
...@@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -335,6 +339,7 @@ class LoRAModelManager(AdapterModelManager):
# Used for long context lora. # Used for long context lora.
self.scaling_factor_to_offset: Dict[float, int] = {} self.scaling_factor_to_offset: Dict[float, int] = {}
super().__init__(model) super().__init__(model)
self.supported_lora_modules = get_supported_lora_modules(self.model) self.supported_lora_modules = get_supported_lora_modules(self.model)
assert self.supported_lora_modules, "No supported LoRA modules found in" assert self.supported_lora_modules, "No supported LoRA modules found in"
f"{self.model.__class__.__name__}." f"{self.model.__class__.__name__}."
...@@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -350,6 +355,7 @@ class LoRAModelManager(AdapterModelManager):
# In case the model only supports LoRA for # In case the model only supports LoRA for
# text modules (e.g. ChatGLM) # text modules (e.g. ChatGLM)
and hasattr(self.model, "get_mm_mapping")) and hasattr(self.model, "get_mm_mapping"))
self.is_pooling_model = is_pooling_model(self.model)
self.packed_modules: Dict[str, List[str]] = {} self.packed_modules: Dict[str, List[str]] = {}
self.modules: Dict[str, BaseLayerWithLoRA] = {} self.modules: Dict[str, BaseLayerWithLoRA] = {}
# Dict instead of a Set for compatibility with LRUCache. # Dict instead of a Set for compatibility with LRUCache.
...@@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -389,7 +395,7 @@ class LoRAModelManager(AdapterModelManager):
lora_model.id, index) lora_model.id, index)
self.lora_index_to_id[index] = lora_model.id self.lora_index_to_id[index] = lora_model.id
for module_name, module in self.modules.items(): for module_name, module in self.modules.items():
module_lora = lora_model.get_lora(module_name) module_lora = self._get_lora_layer_weights(lora_model, module_name)
if module_lora: if module_lora:
module_lora.optimize() module_lora.optimize()
# Bias is not explicitly enabled with the flag enable_lora_bias. # Bias is not explicitly enabled with the flag enable_lora_bias.
...@@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager): ...@@ -626,7 +632,7 @@ class LoRAModelManager(AdapterModelManager):
replaced_module: Set[str] = set() replaced_module: Set[str] = set()
has_replacement = False has_replacement = False
for r in new_module_names: for r in new_module_names:
lora = lora_model.get_lora(r) lora = self._get_lora_layer_weights(lora_model, r)
replacement_loras.append(lora) replacement_loras.append(lora)
if lora: if lora:
has_replacement = True has_replacement = True
...@@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager): ...@@ -637,12 +643,34 @@ class LoRAModelManager(AdapterModelManager):
if replacement_loras[i]: if replacement_loras[i]:
continue continue
replacement_loras[i] = None replacement_loras[i] = None
# HACK Temporary solution for the pool model.
if self.is_pooling_model and not lora_model.check_lora_name(
module_name):
replaced_module_name = module_name.replace("model.", "")
if lora_model.check_lora_name(module_name):
module_name = replaced_module_name
lora_model.loras[module_name] = PackedLoRALayerWeights.pack( lora_model.loras[module_name] = PackedLoRALayerWeights.pack(
replacement_loras) replacement_loras)
# Remove the modules that have been replaced. # Remove the modules that have been replaced.
for module in replaced_module: for module in replaced_module:
lora_model.loras.pop(module, None) lora_model.loras.pop(module, None)
def _get_lora_layer_weights(
        self, lora_model: LoRAModel,
        module_name: str) -> Optional[LoRALayerWeights]:
    """Return the LoRA weights that ``lora_model`` holds for a module.

    For pooling models (``self.is_pooling_model``), when ``module_name`` is
    not registered in ``lora_model``, the lookup is retried once with a
    leading ``"model."`` stripped from the name.

    Args:
        lora_model: Adapter whose per-module weights are queried.
        module_name: Name of the module to look up.

    Returns:
        The result of ``lora_model.get_lora`` for the resolved name; this is
        ``None`` when neither the original nor the stripped name is
        registered.
    """
    # Fast path: not a pooling model, or the name is registered as-is.
    if (not self.is_pooling_model
            or lora_model.check_lora_name(module_name)):
        return lora_model.get_lora(module_name)

    prefix = "model."
    # Strip only a *leading* "model.". A blanket str.replace would also
    # rewrite later occurrences (e.g. the tail of "language_model.") and
    # produce a name that was never a real module path.
    if module_name.startswith(prefix):
        stripped_name = module_name[len(prefix):]
        if lora_model.check_lora_name(stripped_name):
            logger.info_once(
                "For the pool model, successfully loaded the LoRA weights "
                "after removing the prefix 'model.'.")
            return lora_model.get_lora(stripped_name)

    # Nothing matched: fall back to the original name so the caller gets
    # get_lora's own "not found" result.
    return lora_model.get_lora(module_name)
def deactivate_adapter(self, adapter_id: int) -> bool: def deactivate_adapter(self, adapter_id: int) -> bool:
return deactivate_adapter(adapter_id, self._active_adapters, return deactivate_adapter(adapter_id, self._active_adapters,
self._deactivate_adapter) self._deactivate_adapter)
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 16,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 32,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 16,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 1,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
...@@ -783,8 +783,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -783,8 +783,12 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
**config, **config,
) )
else: else:
config = config.copy()
BLOCK_SIZE_K = config.pop("BLOCK_SIZE_K")
if block_shape is not None:
BLOCK_SIZE_K = min(BLOCK_SIZE_K, min(block_shape[0],
block_shape[1]))
fused_moe_kernel[grid]( fused_moe_kernel[grid](
A, A,
B, B,
...@@ -823,6 +827,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor, ...@@ -823,6 +827,7 @@ def invoke_fused_moe_kernel(A: torch.Tensor,
compute_type=compute_type, compute_type=compute_type,
use_fp8_w8a8=use_fp8_w8a8, use_fp8_w8a8=use_fp8_w8a8,
use_int8_w8a16=use_int8_w8a16, use_int8_w8a16=use_int8_w8a16,
BLOCK_SIZE_K=BLOCK_SIZE_K,
**config, **config,
) )
......
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
{
"1": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"8": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"16": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"24": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"32": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"48": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"64": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"96": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"128": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 16,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"256": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"512": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1024": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 8,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"1536": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"2048": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 1,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"3072": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
},
"4096": {
"BLOCK_SIZE_K": 128,
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"GROUP_SIZE_M": 32,
"kpack": 1,
"matrix_instr_nonkdim": 16,
"num_warps": 4
}
}
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment