"tests/vscode:/vscode.git/clone" did not exist on "b6087a6beead9165f4c77ceba592b3651bb37de9"
Commit 2216a4e5 authored by zhuwenwen
Browse files

Merge remote-tracking branch 'mirror/main'

parents ad385667 51c24c97
......@@ -7,8 +7,8 @@ from typing import List, Optional, Tuple
import pytest
from transformers import AutoModelForSeq2SeqLM
from vllm.platforms import current_platform
from vllm.sequence import SampleLogprobs
from vllm.utils import is_cpu
from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close
......@@ -35,7 +35,7 @@ def vllm_to_hf_output(
@pytest.mark.parametrize("decoder_prompt_type", list(DecoderPromptType))
@pytest.mark.parametrize("enforce_eager", [True, False])
@pytest.mark.skipif(
is_cpu(),
current_platform.is_cpu(),
reason="CPU backend is not currently supported with encoder/decoder models"
)
def test_encoder_decoder_e2e(
......@@ -50,7 +50,7 @@ def test_encoder_decoder_e2e(
enforce_eager: bool,
) -> None:
'''
End-to-End (E2E) test for the encoder-decoder framework.
End-to-End (E2E) test for the encoder-decoder framework.
This test evaluates the encoder-decoder functionality using the BART
model. We compare the outputs of the Hugging Face and vLLM
implementations to ensure that both implementations produce consistent
......
from typing import List
import pytest
from vllm import LLM
from ..openai.test_vision import TEST_IMAGE_URLS
def test_chat():
    """Send one system+user conversation through ``LLM.chat``.

    A single conversation is expected to yield exactly one output entry.
    """
    engine = LLM(model="meta-llama/Llama-3.2-1B-Instruct")

    system_turn = {"role": "system", "content": "You are a helpful assistant"}
    user_turn = {"role": "user", "content": "Explain the concept of entropy."}

    results = engine.chat([system_turn, user_turn])
    assert len(results) == 1
def test_multi_chat():
    """Submit two independent conversations in one ``LLM.chat`` call.

    The test expects one output entry per conversation.
    """
    engine = LLM(model="meta-llama/Llama-3.2-1B-Instruct")

    def build_conversation(question):
        # Every conversation starts with the same system turn.
        return [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": question},
        ]

    questions = (
        "Explain the concept of entropy.",
        "Explain what among us is.",
    )
    batch = [build_conversation(question) for question in questions]

    results = engine.chat(batch)
    assert len(results) == 2
@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    """Chat with a vision model using several images in one user message.

    Args:
        image_urls: URLs attached to the user turn as ``image_url`` content
            parts, followed by a single text question.
    """
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        # Must admit every image supplied by the parametrization above.
        limit_mm_per_prompt={"image": 2},
    )

    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        },
    } for image_url in image_urls]
    messages = [{
        "role": "user",
        "content": [
            *image_parts,
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    outputs = llm.chat(messages)
    # A single conversation was submitted, so exactly one output is expected.
    # The earlier `len(outputs) >= 0` check could never fail.
    assert len(outputs) == 1
......@@ -4,8 +4,7 @@ from typing import List
import pytest
from vllm import LLM, EmbeddingRequestOutput, PoolingParams
from ...conftest import cleanup
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "intfloat/e5-mistral-7b-instruct"
......@@ -41,7 +40,7 @@ def llm():
del llm
cleanup()
cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[EmbeddingRequestOutput],
......
......@@ -4,9 +4,7 @@ from typing import List
import pytest
from vllm import LLM, RequestOutput, SamplingParams
from ...conftest import cleanup
from ..openai.test_vision import TEST_IMAGE_URLS
from vllm.distributed import cleanup_dist_env_and_memory
MODEL_NAME = "facebook/opt-125m"
......@@ -40,7 +38,7 @@ def llm():
del llm
cleanup()
cleanup_dist_env_and_memory()
def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]):
......@@ -104,90 +102,3 @@ def test_multiple_sampling_params(llm: LLM):
# sampling_params is None, default params should be applied
outputs = llm.generate(PROMPTS, sampling_params=None)
assert len(PROMPTS) == len(outputs)
def test_chat():
    """Run a single system+user conversation through ``LLM.chat``.

    Exactly one output entry is expected for the one conversation.
    """
    engine = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    question = "Explain the concept of entropy."
    conversation = [
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": question},
    ]

    results = engine.chat(conversation)
    assert len(results) == 1
def test_multi_chat():
    """Pass two conversations to ``LLM.chat`` at once.

    The test expects the number of outputs to equal the number of
    conversations submitted.
    """
    engine = LLM(model="meta-llama/Meta-Llama-3-8B-Instruct")

    user_questions = [
        "Explain the concept of entropy.",
        "Explain what among us is.",
    ]
    # Each conversation is a fresh system turn plus one user question.
    conversations = [[
        {"role": "system", "content": "You are a helpful assistant"},
        {"role": "user", "content": user_question},
    ] for user_question in user_questions]

    results = engine.chat(conversations)
    assert len(results) == 2
@pytest.mark.parametrize("image_urls",
                         [[TEST_IMAGE_URLS[0], TEST_IMAGE_URLS[1]]])
def test_chat_multi_image(image_urls: List[str]):
    """Chat with a vision model using several images in one user message.

    Args:
        image_urls: URLs attached to the user turn as ``image_url`` content
            parts, followed by a single text question.
    """
    llm = LLM(
        model="microsoft/Phi-3.5-vision-instruct",
        dtype="bfloat16",
        max_model_len=4096,
        max_num_seqs=5,
        enforce_eager=True,
        trust_remote_code=True,
        # Must admit every image supplied by the parametrization above.
        limit_mm_per_prompt={"image": 2},
    )

    image_parts = [{
        "type": "image_url",
        "image_url": {
            "url": image_url
        },
    } for image_url in image_urls]
    messages = [{
        "role": "user",
        "content": [
            *image_parts,
            {
                "type": "text",
                "text": "What's in this image?"
            },
        ],
    }]

    outputs = llm.chat(messages)
    # A single conversation was submitted, so exactly one output is expected.
    # The earlier `len(outputs) >= 0` check could never fail.
    assert len(outputs) == 1
......@@ -5,10 +5,9 @@ import pytest
from huggingface_hub import snapshot_download
from vllm import LLM
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.lora.request import LoRARequest
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
PROMPTS = [
......@@ -39,7 +38,7 @@ def llm():
del llm
cleanup()
cleanup_dist_env_and_memory()
@pytest.fixture(scope="module")
......
......@@ -5,12 +5,11 @@ import weakref
import jsonschema
import pytest
from vllm.distributed import cleanup_dist_env_and_memory
from vllm.entrypoints.llm import LLM
from vllm.outputs import RequestOutput
from vllm.sampling_params import GuidedDecodingParams, SamplingParams
from ...conftest import cleanup
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
......@@ -23,7 +22,7 @@ def llm():
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
del llm
cleanup()
cleanup_dist_env_and_memory()
@pytest.mark.skip_global_cleanup
......
import pytest
from vllm import LLM
from ...utils import error_on_warning
MODEL_NAME = "facebook/opt-125m"
def test_pos_args_deprecated():
    """Check which ``LLM(...)`` call shapes emit a ``DeprecationWarning``.

    Passing ``tokenizer`` by keyword (with the model given either by keyword
    or positionally) is exercised under ``error_on_warning``. Passing the
    tokenizer, or the tokenizer plus tokenizer mode, positionally must warn
    and name the offending parameters.
    """
    accepted_calls = (
        lambda: LLM(model=MODEL_NAME, tokenizer=MODEL_NAME),
        lambda: LLM(MODEL_NAME, tokenizer=MODEL_NAME),
    )
    for construct in accepted_calls:
        with error_on_warning(DeprecationWarning):
            construct()

    # (expected warning text, positional arguments that should trigger it)
    deprecated_calls = (
        ("'tokenizer'", (MODEL_NAME, MODEL_NAME)),
        ("'tokenizer', 'tokenizer_mode'", (MODEL_NAME, MODEL_NAME, "auto")),
    )
    for expected_text, positional_args in deprecated_calls:
        with pytest.warns(DeprecationWarning, match=expected_text):
            LLM(*positional_args)
import sys
from vllm import LLM, SamplingParams
from vllm.distributed import cleanup_dist_env_and_memory
def test_lazy_outlines(sample_regex):
......@@ -14,6 +15,7 @@ def test_lazy_outlines(sample_regex):
]
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
# Create an LLM without guided decoding as a baseline.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
gpu_memory_utilization=0.3)
......@@ -26,10 +28,15 @@ def test_lazy_outlines(sample_regex):
# make sure outlines is not imported
assert 'outlines' not in sys.modules
# Destroy the LLM object and free up the GPU memory.
del llm
cleanup_dist_env_and_memory()
# Create an LLM with guided decoding enabled.
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.3)
gpu_memory_utilization=0.6)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[
......
"""Tests for HF_HUB_OFFLINE mode"""
import importlib
import sys
import weakref
import pytest
from vllm import LLM
from ...conftest import cleanup
MODEL_NAME = "facebook/opt-125m"
from vllm.distributed import cleanup_dist_env_and_memory
# Keyword-argument sets, one per model, that the offline-mode test expands
# into ``LLM(**model_config)``. Both entries share the same small limits and
# differ only in model, GPU memory fraction and any extra options.
MODEL_CONFIGS = [
    {
        "model": model_name,
        "enforce_eager": True,
        "gpu_memory_utilization": gpu_fraction,
        "max_model_len": 64,
        "max_num_batched_tokens": 64,
        "max_num_seqs": 64,
        "tensor_parallel_size": 1,
        **extra_options,
    }
    for model_name, gpu_fraction, extra_options in (
        ("facebook/opt-125m", 0.20, {}),
        (
            "mistralai/Mistral-7B-Instruct-v0.1",
            0.95,
            {"tokenizer_mode": "mistral"},
        ),
    )
]
@pytest.fixture(scope="module")
def llm():
# pytest caches the fixture so we use weakref.proxy to
# enable garbage collection
llm = LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
with llm.deprecate_legacy_api():
yield weakref.proxy(llm)
def cache_models():
# Cache model files first
for model_config in MODEL_CONFIGS:
LLM(**model_config)
cleanup_dist_env_and_memory()
del llm
cleanup()
yield
@pytest.mark.skip_global_cleanup
def test_offline_mode(llm: LLM, monkeypatch):
# we use the llm fixture to ensure the model files are in-cache
del llm
@pytest.mark.usefixtures("cache_models")
def test_offline_mode(monkeypatch):
# Set HF to offline mode and ensure we can still construct an LLM
try:
monkeypatch.setenv("HF_HUB_OFFLINE", "1")
# Need to re-import huggingface_hub and friends to setup offline mode
_re_import_modules()
# Cached model files should be used in offline mode
LLM(model=MODEL_NAME,
max_num_batched_tokens=4096,
tensor_parallel_size=1,
gpu_memory_utilization=0.10,
enforce_eager=True)
for model_config in MODEL_CONFIGS:
LLM(**model_config)
finally:
# Reset the environment after the test
# NB: Assuming tests are run in online mode
......
......@@ -16,9 +16,6 @@ from .test_completion import zephyr_lora_files # noqa: F401
# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"
@pytest.fixture(scope="module")
......@@ -851,14 +848,28 @@ async def test_response_format_json_object(client: openai.AsyncOpenAI):
@pytest.mark.asyncio
async def test_response_format_json_schema(client: openai.AsyncOpenAI):
prompt = 'what is 1+1? The format is "result": 2'
# Check that this prompt cannot lead to a valid JSON without json_schema
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role":
"user",
"content": ('what is 1+1? please respond with a JSON object, '
'the format is {"result": 2}')
"role": "user",
"content": prompt
}],
)
content = resp.choices[0].message.content
assert content is not None
with pytest.raises((json.JSONDecodeError, AssertionError)):
loaded = json.loads(content)
assert loaded == {"result": 2}, loaded
for _ in range(2):
resp = await client.chat.completions.create(
model=MODEL_NAME,
messages=[{
"role": "user",
"content": prompt
}],
response_format={
"type": "json_schema",
......
......@@ -340,6 +340,40 @@ async def test_completion_streaming(client: openai.AsyncOpenAI,
assert "".join(chunks) == single_output
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME, "zephyr-lora", "zephyr-pa"],
)
async def test_parallel_streaming(client: openai.AsyncOpenAI, model_name: str):
    """Stream a completion request that asks for several parallel samples.

    All samples arrive interleaved on one stream; each streamed choice carries
    an ``index`` identifying the sample it belongs to. The test regroups the
    text by that index and checks that every sample reported a finish reason
    and produced exactly ``max_tokens`` pieces.
    """
    num_samples = 3
    token_budget = 5

    stream = await client.completions.create(
        model=model_name,
        prompt="What is an LLM?",
        max_tokens=token_budget,
        n=num_samples,
        stream=True,
    )

    pieces_by_sample: List[List[str]] = [[] for _ in range(num_samples)]
    finished_samples = 0
    async for event in stream:
        choice = event.choices[0]
        pieces_by_sample[choice.index].append(choice.text)
        if choice.finish_reason is not None:
            finished_samples += 1

    assert finished_samples == num_samples
    for pieces in pieces_by_sample:
        assert len(pieces) == token_budget
        print("".join(pieces))
@pytest.mark.asyncio
@pytest.mark.parametrize(
"model_name",
......
......@@ -22,12 +22,12 @@ class MockHFConfig:
@dataclass
class MockModelConfig:
task = "generate"
tokenizer = MODEL_NAME
trust_remote_code = False
tokenizer_mode = "auto"
max_model_len = 100
tokenizer_revision = None
embedding_mode = False
multimodal_config = MultiModalConfig()
hf_config = MockHFConfig()
......
......@@ -6,7 +6,7 @@ import pytest
from ...utils import RemoteOpenAIServer
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
MODEL_NAME = "meta-llama/Llama-3.2-1B"
@pytest.mark.asyncio
......
......@@ -23,6 +23,8 @@ TEST_IMAGE_URLS = [
@pytest.fixture(scope="module")
def server():
args = [
"--task",
"generate",
"--dtype",
"bfloat16",
"--max-model-len",
......
......@@ -18,7 +18,8 @@ PHI3V_MODEL_ID = "microsoft/Phi-3.5-vision-instruct"
@pytest.fixture(scope="module")
def phi3v_model_config():
return ModelConfig(PHI3V_MODEL_ID,
PHI3V_MODEL_ID,
task="generate",
tokenizer=PHI3V_MODEL_ID,
tokenizer_mode="auto",
trust_remote_code=True,
dtype="bfloat16",
......@@ -387,3 +388,29 @@ def test_parse_chat_messages_rejects_too_many_images_across_messages(
"text": "What about these two?"
}]
}], phi3v_model_config, phi3v_tokenizer)
def test_parse_chat_messages_multiple_images_uncommon_input(
    phi3v_model_config,
    phi3v_tokenizer,
    image_url,
):
    """Parse a user message whose content parts omit the usual ``type`` key.

    The content is a bare string followed by two ``{"image_url": ...}``
    dicts. The parsed conversation is expected to place one image
    placeholder per image ahead of the text, and the multimodal data is
    expected to hold two images.
    """
    user_content = ["What's in these images?"]
    # Two separate dict objects, mirroring two distinct content parts.
    user_content.extend({"image_url": image_url} for _ in range(2))
    request_messages = [{"role": "user", "content": user_content}]

    conversation, mm_data = parse_chat_messages(request_messages,
                                                phi3v_model_config,
                                                phi3v_tokenizer)

    expected_conversation = [{
        "role": "user",
        "content": "<|image_1|>\n<|image_2|>\nWhat's in these images?",
    }]
    assert conversation == expected_conversation
    _assert_mm_data_is_image_input(mm_data, 2)
......@@ -19,22 +19,23 @@ def test_env(name: str, device: str, monkeypatch):
override_backend_env_variable(monkeypatch, name)
if device == "cpu":
with patch("vllm.attention.selector.is_cpu", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
with patch("vllm.attention.selector.current_platform.is_cpu",
return_value=True):
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "TORCH_SDPA"
elif device == "hip":
with patch("vllm.attention.selector.is_hip", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "ROCM_FLASH"
elif device == "openvino":
with patch("vllm.attention.selector.is_openvino", return_value=True):
backend = which_attn_to_use(16, None, torch.float16, torch.float16,
16, False)
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == "OPENVINO"
else:
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
backend = which_attn_to_use(16, torch.float16, torch.float16, 16,
False)
assert backend.name == name
......@@ -46,37 +47,32 @@ def test_flash_attn(monkeypatch):
# Unsupported CUDA arch
with patch("torch.cuda.get_device_capability", return_value=(7, 5)):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported data type
backend = which_attn_to_use(16, None, torch.float8_e4m3fn, None, 16, False)
backend = which_attn_to_use(16, torch.float8_e4m3fn, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported kv cache data type
backend = which_attn_to_use(16, None, torch.float16, "fp8", 16, False)
backend = which_attn_to_use(16, torch.float16, "fp8", 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported block size
backend = which_attn_to_use(16, None, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported sliding window
backend = which_attn_to_use(16, 1, torch.float16, None, 16, False)
backend = which_attn_to_use(16, torch.float16, None, 8, False)
assert backend.name != STR_FLASH_ATTN_VAL
# flash-attn is not installed
with patch.dict('sys.modules', {'vllm_flash_attn': None}):
backend = which_attn_to_use(16, None, torch.float16, None, 16, False)
backend = which_attn_to_use(16, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Unsupported head size
backend = which_attn_to_use(17, None, torch.float16, None, 16, False)
backend = which_attn_to_use(17, torch.float16, None, 16, False)
assert backend.name != STR_FLASH_ATTN_VAL
# Attention-free models should bypass env and use PlaceholderAttention
backend = which_attn_to_use(16, None, torch.float16, torch.float16, 16,
True)
backend = which_attn_to_use(16, torch.float16, torch.float16, 16, True)
assert backend.name != STR_FLASH_ATTN_VAL
......@@ -84,4 +80,4 @@ def test_invalid_env(monkeypatch):
"""Throw an exception if the backend name is invalid."""
override_backend_env_variable(monkeypatch, STR_INVALID_VAL)
with pytest.raises(ValueError):
which_attn_to_use(16, None, torch.float16, None, 16, False)
\ No newline at end of file
which_attn_to_use(16, torch.float16, None, 16, False)
......@@ -78,6 +78,7 @@ def ref_paged_attn(
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
@pytest.mark.parametrize("sliding_window", [None, 256])
@torch.inference_mode()
def test_flash_attn_with_paged_kv(
kv_lens: List[int],
......@@ -87,6 +88,7 @@ def test_flash_attn_with_paged_kv(
block_size: int,
soft_cap: Optional[float],
num_blocks: int,
sliding_window: Optional[int],
) -> None:
torch.set_default_device("cuda")
seed_everything(0)
......@@ -96,6 +98,8 @@ def test_flash_attn_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_kv_len = max(kv_lens)
scale = head_size**-0.5
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
query = torch.randn(num_seqs, num_query_heads, head_size, dtype=dtype)
key_cache = torch.randn(num_blocks,
......@@ -121,18 +125,18 @@ def test_flash_attn_with_paged_kv(
block_table=block_tables,
cache_seqlens=kv_lens_tensor,
softcap=soft_cap if soft_cap is not None else 0,
window_size=window_size,
).squeeze(1)
ref_output = ref_paged_attn(
query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
)
ref_output = ref_paged_attn(query=query,
key_cache=key_cache,
value_cache=value_cache,
query_lens=[1] * num_seqs,
kv_lens=kv_lens,
block_tables=block_tables,
scale=scale,
soft_cap=soft_cap,
sliding_window=sliding_window)
torch.testing.assert_close(output, ref_output, atol=2e-2, rtol=1e-2), \
f"{torch.max(torch.abs(output - ref_output))}"
......@@ -141,7 +145,7 @@ def test_flash_attn_with_paged_kv(
@pytest.mark.parametrize("num_heads", NUM_HEADS)
@pytest.mark.parametrize("head_size", HEAD_SIZES)
@pytest.mark.parametrize("block_size", BLOCK_SIZES)
@pytest.mark.parametrize("sliding_window", [None])
@pytest.mark.parametrize("sliding_window", [None, 256])
@pytest.mark.parametrize("dtype", DTYPES)
@pytest.mark.parametrize("soft_cap", [None, 10.0, 50.0])
@pytest.mark.parametrize("num_blocks", NUM_BLOCKS)
......@@ -166,8 +170,7 @@ def test_varlen_with_paged_kv(
assert num_query_heads % num_kv_heads == 0
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)
window_size = ((sliding_window,
sliding_window) if sliding_window is not None else
window_size = ((sliding_window - 1, 0) if sliding_window is not None else
(-1, -1))
scale = head_size**-0.5
......
......@@ -80,7 +80,7 @@ def machete_quantize_and_pack(w: torch.Tensor,
w_q = w_q.t().contiguous().t() # convert to col major
w_q_machete = ops.machete_prepack_B(w_q, wtype)
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype))
opcheck(torch.ops._C.machete_prepack_B, (w_q, wtype.id))
return w_ref, w_q_machete, w_s, w_zp
......@@ -153,9 +153,10 @@ def test_machete_all_schedules(shape, atype: torch.dtype,
schedule=schedule,
)
opcheck(torch.ops._C.machete_gemm,
(a, w_q_machete, wtype, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
opcheck(
torch.ops._C.machete_gemm,
(a, w_q_machete, wtype.id, w_s, maybe_convert_zeropoints(
w_zp, w_s), group_size, None, None, None, schedule))
# Relax atol as our reduction dim becomes larger (more rounding error)
# Relax atol when we have zeropoints since the way machete applies
......
......@@ -225,7 +225,7 @@ def test_gptq_marlin_gemm(
opcheck(
torch.ops._C.gptq_marlin_gemm,
(a_input, marlin_q_w, marlin_s, marlin_zp, g_idx, sort_indices,
workspace.scratch, quant_type, a_input.shape[0], b_weight.shape[1],
workspace.scratch, quant_type.id, a_input.shape[0], b_weight.shape[1],
a_input.shape[1], is_k_full, False, use_fp32_reduce),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
......@@ -254,6 +254,16 @@ def test_gptq_marlin_gemm(
assert max_diff < 0.04
# TODO: find better way to test this?
@torch.compile(fullgraph=True)
def marlin_24_gemm_tester(a_input, marlin_24_q_w_comp, marlin_24_meta,
                          marlin_24_s, scratch, quant_type, size_m, size_n,
                          size_k):
    """Call ``ops.gptq_marlin_24_gemm`` from inside a compiled function.

    Every argument is forwarded unchanged and in the same order; the op's
    result is returned as-is. The wrapper is decorated with
    ``torch.compile(fullgraph=True)``.
    """
    gemm_output = ops.gptq_marlin_24_gemm(
        a_input,
        marlin_24_q_w_comp,
        marlin_24_meta,
        marlin_24_s,
        scratch,
        quant_type,
        size_m,
        size_n,
        size_k,
    )
    return gemm_output
@pytest.mark.skipif(not is_quant_method_supported("gptq_marlin"),
reason="Marlin is not supported on this GPU type.")
@pytest.mark.parametrize("k_chunk", MARLIN_24_K_CHUNKS)
......@@ -282,11 +292,11 @@ def test_gptq_marlin_24_gemm(k_chunk, n_chunk, quant_type, group_size,
opcheck(torch.ops._C.gptq_marlin_24_gemm,
(a_input, marlin_24_q_w_comp, marlin_24_meta, marlin_24_s,
workspace_24.scratch, quant_type, a_input.shape[0],
workspace_24.scratch, quant_type.id, a_input.shape[0],
b_weight.shape[1], a_input.shape[1]),
test_utils=DEFAULT_OPCHECK_TEST_UTILS)
output = ops.gptq_marlin_24_gemm(
output = marlin_24_gemm_tester(
a_input,
marlin_24_q_w_comp,
marlin_24_meta,
......
......@@ -240,8 +240,8 @@ def test_fused_marlin_moe(
requires_grad=False)
opcheck(torch.ops._moe_C.marlin_gemm_moe,
(a, qweight1, sorted_token_ids, topk_weights, topk_ids,
scales1, zp, g_idx1, sort_indices1, workspace, quant_type, m,
2 * n, k, True, e, topk, block_size_m, True, False))
scales1, zp, g_idx1, sort_indices1, workspace, quant_type.id,
m, 2 * n, k, True, e, topk, block_size_m, True, False))
@pytest.mark.skip("This test is here for the sake of debugging, "
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment