Commit 2216a4e5 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
......@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close
......@@ -35,7 +35,7 @@ def vllm_to_hf_output(
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif(
is_cpu(),
current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models"
)
def test_encoder_decoder_e2e(
......
from typing import List
import pytest
from vllm import LLM
from ..openai.test_vision import TEST_IMAGE_URLS
def test_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
def test_multi_chat():
llm = LLM(model="meta-llama/Llama-3.2-1B-Instruct")
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."
conversation1 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
conversation2 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
]
messages = [conversation1, conversation2]
outputs = llm.chat(messages)
assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
)
messages = [{
"role":
"user",
"content": [
*({
"type": "image_url",
"image_url": {
"url": image_url
}
} for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
},
],
}]
outputs = llm.chat(messages)
assert len(outputs) >= 0
......@@ -4,8 +4,7 @@ from typing import List
import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from ...conftest import cleanup
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
......@@ -41,7 +40,7 @@ def llm():
del llm
cleanup()
cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
......
......@@ -4,9 +4,7 @@ from typing import List
import pytest
from vllm import LLM, RequestOutput, SamplingParams
from ...conftest import cleanup
from ..openai.test_vision import TEST_IMAGE_URLS
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "facebook/opt-125m"
......@@ -40,7 +38,7 @@ def llm():
del llm
cleanup()
cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
......@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs)
def test_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
messages = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
outputs = llm.chat(messages)
assert len(outputs) == 1
def test_multi_chat():
llm = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")
prompt1 = "Explain the concept of entropy."
prompt2 = "Explain what among us is."
conversation1 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt1
},
]
conversation2 = [
{
"role": "system",
"content": "You are a helpful assistant"
},
{
"role": "user",
"content": prompt2
},
]
messages = [conversation1, conversation2]
outputs = llm.chat(messages)
assert len(outputs) == 2
@pytest.mark.parametrize("image_urls",
[[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
llm = LLM(
model="microsoft/Phi-3.5-vision-instruct",
dtype="bfloat16",
max_model_len=4096,
max_num_seqs=5,
enforce_eager=True,
trust_remote_code=True,
limit_mm_per_prompt={"image": 2},
)
messages = [{
"role":
"user",
"content": [
*({
"type": "image_url",
"image_url": {
"url": image_url
}
} for image_url in image_urls),
{
"type": "text",
"text": "What's in this image?"
},
],
}]
outputs = llm.chat(messages)
assert len(outputs) >= 0
......@@ -5,10 +5,9 @@ import pytest
from huggingface_hub import snapshot_download
from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
PROMPTS = [
......@@ -39,7 +38,7 @@ def llm():
del llm
cleanup()
cleanup_dist_env_and_memory()
@pytest.fixture(scope="module")
......
......@@ -5,12 +5,11 @@ import weakref
import jsonschema
import pytest
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
......@@ -23,7 +22,7 @@ def llm():
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
......
import pytest
from vllm import LLM
from ...utils import error_on_warning
MODEL_NAME = "facebook/opt-125m"
def test_pos_args_deprecated():
with error_on_warning(DeprecationWarning):
LLM(model=MODEL_NAME, tokenizer=MODEL_NAME)
with error_on_warning(DeprecationWarning):
LLM(MODEL_NAME, tokenizer=MODEL_NAME)
with pytest.warns(DeprecationWarning, match="'tokenizer'"):
LLM(MODEL_NAME, MODEL_NAME)
with pytest.warns(DeprecationWarning,
match="'tokenizer', 'tokenizer_mode'"):
LLM(MODEL_NAME, MODEL_NAME, "auto")
import sys
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
def test_lazy_outlines(sample_regex):
......@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM without guided decoding as a baseline.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
gpu_memory_utilization=0.3)
......@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported
assert 'outlines' not in sys.modules
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()
# Create an LLM with guided decoding enabled.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3)
gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[
......
"""Tests for HF_HUB_OFFLINE mode"""
import importlib
import sys
import weakref
import pytest
from vllm import LLM
from ...conftest import cleanup
MODEL_NAME = "facebook/opt-125m"
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_CONFIGS = [
{
"model": "facebook/opt-125m",
"enforce_eager": True,
"gpu_memory_utilization": 0.20,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
},
{
"model": "mistralai/Mistral-7B-Instruct-v0.1",
"enforce_eager": True,
"gpu_memory_utilization": 0.95,
"max_model_len": 64,
"max_num_batched_tokens": 64,
"max_num_seqs": 64,
"tensor_parallel_size": 1,
"tokenizer_mode": "mistral",
},
]
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
def cache_models():
# Cache model files first
for model_config in MODEL_CONFIGS:
LLM(**model_config)
cleanup_dist_env_and_memory()
del llm
cleanup()
yield
@pytest.mark.skip_global_cleanup
def test_offline_mode(llm: LLM, monkeypatch):
# we use the llm fixture to ensure the model files are in-cache
del llm
@pytest.mark.usefixtures("cache_models")
def test_offline_mode(monkeypatch):
# Set HF to offline mode and ensure we can still construct an LLM
try:
monkeypatch.setenv("HF_HUB_OFFLINE", "1")
# Need to re-import huggingface_hub and friends to setup offline mode
_re_import_modules()
# Cached model files should be used in offline mode
LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
for model_config in MODEL_CONFIGS:
LLM(**model_config)
finally:
# Reset the environment after the test
# NB: Assuming tests are run in online mode
......
......@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module")
......@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
prompt = 'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
"role": "user",
"content": prompt
}],
)
content = resp.choices[0].message.content
assert content is not None
with pytest.raises((json.JSONDecodeError, AssertionError)):
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": prompt
}],
response_format={
"type": "json_schema",
......
......@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
[MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
"""Streaming for parallel sampling.
The tokens from multiple samples, are flattened into a single stream,
with an index to indicate which sample the token belongs to.
"""
prompt = "What is an LLM?"
n = 3
max_tokens = 5
stream = await client.completions.create(model=model_name,
prompt=prompt,
max_tokens=max_tokens,
n=n,
stream=True)
chunks: List[List[str]] = [[] for i in range(n)]
finish_reason_count = 0
async for chunk in stream:
index = chunk.choices[0].index
text = chunk.choices[0].text
chunks[index].append(text)
if chunk.choices[0].finish_reason is not None:
finish_reason_count += 1
assert finish_reason_count == n
for chunk in chunks:
assert len(chunk) == max_tokens
print("".join(chunk))
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
......
......@@ -22,12 +22,12 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
......
......@@ -6,7 +6,7 @@ import pytest
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "meta-llama/Llama-3.2-1B"
@pytest.mark.asyncio
......
......@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def server():
args = [
"--task",
"generate",
"--dtype",
"bfloat16",
"--max-model-len",
......
......@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@pytest.fixture(scope="module")
def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID,
PHI3V_MODEL_ID,
task="generate",
tokenizer=PHI3V_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
......@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text": "What about these two?"
}]
}], phi3v_model_config, phi3v_tokenizer)
def test_parse_chat_messages_multiple_images_uncommon_input(
phi3v_model_config,
phi3v_tokenizer,
image_url,
):
conversation, mm_data = parse_chat_messages([{
"role":
"user",
"content": [
"What's in these images?", {
"image_url": image_url
}, {
"image_url": image_url
}
]
}], phi3v_model_config, phi3v_tokenizer)
assert conversation == [{
"role":
"user",
"content":
"<|image_1|>\n<|image_2|>\nWhat's in these images?"
}]
_assert_mm_data_is_image_input(mm_data, 2)
......@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)
if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
with patch("vllm.attention.selector.current_platform.is_cpu",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name
......@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
backend = which_attn_to_use(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
backend = which_attn_to_use(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
True)
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
assert backend.name != STR_FLASH_ATTN_VAL
......@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(16, None, torch.float16, None, 16, False)
\ No newline at end of file
which_attn_to_use(16, torch.float16, None, 16, False)
......@@ -78,6 +78,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
......@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
sliding_window: Optional[int],
) -> None:
torch.set_default_device("cuda")
seed_everything(0)
......@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks,
......@@ -121,10 +125,10 @@ def test_flash_attn_with_paged_kv(
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
).squeeze(1)
ref_output = ref_paged_attn(
query=query,
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
......@@ -132,7 +136,7 @@ def test_flash_attn_with_paged_kv(
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
......@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
......@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window,
sliding_window) if sliding_window is not None else
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
scale = head_size**-0.5
......
......@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
return w_ref, w_q_machete, w_s, w_zp
......@@ -153,8 +153,9 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule,
)
opcheck(torch.ops._C.machete_gemm,
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
opcheck(
torch.ops._C.machete_gemm,
(a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error)
......
......@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
......@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert max_diff < 0.04
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m, size_n,
size_k):
return ops.gptq_marlin_24_gemm(a_input, marlin_24_q_w_comp, marlin_24_meta,
marlin_24_s, scratch, quant_type, size_m,
size_n, size_k)
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
......@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type, a_input.shape[0],
workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_24_gemm(
output = marlin_24_gemm_tester(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
......
......@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False))
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
m, 2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment