Commit 78800ecf authored by zhuwenwen
Browse files

Skip tests for cross-attention and other features not supported on ROCm

parent 129fce94
...@@ -14,6 +14,8 @@ from vllm.utils import is_cpu ...@@ -14,6 +14,8 @@ from vllm.utils import is_cpu
from ..conftest import DecoderPromptType from ..conftest import DecoderPromptType
from ..models.utils import check_logprobs_close from ..models.utils import check_logprobs_close
from ..utils import models_path_prefix from ..utils import models_path_prefix
from vllm.utils import is_hip
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
def vllm_to_hf_output( def vllm_to_hf_output(
...@@ -30,6 +32,8 @@ def vllm_to_hf_output( ...@@ -30,6 +32,8 @@ def vllm_to_hf_output(
return output_ids, hf_output_str, out_logprobs return output_ids, hf_output_str, out_logprobs
@pytest.mark.skipif(is_hip(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "facebook/bart-large-cnn")]) @pytest.mark.parametrize("model", [os.path.join(models_path_prefix, "facebook/bart-large-cnn")])
@pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("dtype", ["bfloat16"])
@pytest.mark.parametrize("max_tokens", [128]) @pytest.mark.parametrize("max_tokens", [128])
......
...@@ -4,6 +4,8 @@ import os ...@@ -4,6 +4,8 @@ import os
import pytest_asyncio import pytest_asyncio
from ...utils import RemoteOpenAIServer, models_path_prefix from ...utils import RemoteOpenAIServer, models_path_prefix
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
from vllm.utils import is_hip
MODEL_NAME = os.path.join(models_path_prefix, "facebook/bart-base") MODEL_NAME = os.path.join(models_path_prefix, "facebook/bart-base")
...@@ -26,6 +28,8 @@ async def client(server): ...@@ -26,6 +28,8 @@ async def client(server):
yield async_client yield async_client
@pytest.mark.skipif(is_hip(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.asyncio @pytest.mark.asyncio
@pytest.mark.parametrize("model_name", [MODEL_NAME]) @pytest.mark.parametrize("model_name", [MODEL_NAME])
async def test_single_completion(client: openai.AsyncOpenAI, model_name: str): async def test_single_completion(client: openai.AsyncOpenAI, model_name: str):
......
...@@ -13,6 +13,7 @@ from vllm.sequence import SampleLogprobs ...@@ -13,6 +13,7 @@ from vllm.sequence import SampleLogprobs
from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner from ....conftest import IMAGE_ASSETS, HfRunner, VllmRunner
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
from ....utils import models_path_prefix from ....utils import models_path_prefix
from vllm.utils import is_hip
# The image token is placed before "user" on purpose so that the test can pass # The image token is placed before "user" on purpose so that the test can pass
...@@ -122,6 +123,8 @@ def run_test( ...@@ -122,6 +123,8 @@ def run_test(
) )
@pytest.mark.skipif(is_hip(),
reason="Xformers backend is not supported on ROCm.")
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"size_factors", "size_factors",
...@@ -161,6 +164,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, ...@@ -161,6 +164,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model, size_factors,
) )
@pytest.mark.skipif(is_hip(),
reason="Xformers backend is not supported on ROCm.")
@pytest.mark.parametrize("model", models) @pytest.mark.parametrize("model", models)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"size_factors", "size_factors",
......
...@@ -7,6 +7,7 @@ import pytest ...@@ -7,6 +7,7 @@ import pytest
import torch import torch
import torch.nn.functional as F import torch.nn.functional as F
from ....utils import models_path_prefix from ....utils import models_path_prefix
from vllm.utils import is_hip
MODELS = [ MODELS = [
os.path.join(models_path_prefix, "intfloat/e5-mistral-7b-instruct"), os.path.join(models_path_prefix, "intfloat/e5-mistral-7b-instruct"),
...@@ -21,6 +22,8 @@ def compare_embeddings(embeddings1, embeddings2): ...@@ -21,6 +22,8 @@ def compare_embeddings(embeddings1, embeddings2):
return similarities return similarities
@pytest.mark.skipif(is_hip(),
reason="Consistent with NV.")
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("dtype", ["half"])
def test_models( def test_models(
......
...@@ -7,6 +7,8 @@ from typing import List, Optional, Tuple, Type ...@@ -7,6 +7,8 @@ from typing import List, Optional, Tuple, Type
from vllm.utils import is_cpu from vllm.utils import is_cpu
from ....utils import models_path_prefix from ....utils import models_path_prefix
from vllm.utils import is_hip
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
if not is_cpu(): if not is_cpu():
# CPU backend is not currently supported with encoder/decoder models # CPU backend is not currently supported with encoder/decoder models
...@@ -22,6 +24,7 @@ if not is_cpu(): ...@@ -22,6 +24,7 @@ if not is_cpu():
HfRunner, VllmRunner) HfRunner, VllmRunner)
from ....utils import multi_gpu_test from ....utils import multi_gpu_test
from ...utils import check_logprobs_close from ...utils import check_logprobs_close
from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP
MODELS = [os.path.join(models_path_prefix, "facebook/bart-base"), os.path.join(models_path_prefix, "facebook/bart-large-cnn")] MODELS = [os.path.join(models_path_prefix, "facebook/bart-base"), os.path.join(models_path_prefix, "facebook/bart-large-cnn")]
...@@ -178,6 +181,8 @@ if not is_cpu(): ...@@ -178,6 +181,8 @@ if not is_cpu():
num_outputs_0_skip_tokens=hf_skip_tokens, num_outputs_0_skip_tokens=hf_skip_tokens,
) )
@pytest.mark.skipif(is_hip(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float", "bfloat16"]) @pytest.mark.parametrize("dtype", ["float", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [64]) @pytest.mark.parametrize("max_tokens", [64])
...@@ -199,6 +204,8 @@ if not is_cpu(): ...@@ -199,6 +204,8 @@ if not is_cpu():
tensor_parallel_size=1, tensor_parallel_size=1,
) )
@pytest.mark.skipif(is_hip(),
reason=STR_NOT_IMPL_ENC_DEC_ROCM_HIP)
@multi_gpu_test(num_gpus=2) @multi_gpu_test(num_gpus=2)
@pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"]) @pytest.mark.parametrize("distributed_executor_backend", ["ray", "mp"])
@pytest.mark.parametrize("model", ["facebook/bart-large-cnn"]) @pytest.mark.parametrize("model", ["facebook/bart-large-cnn"])
......
...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso ...@@ -14,6 +14,7 @@ from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tenso
from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( from vllm.model_executor.layers.quantization.compressed_tensors.utils import (
QuantizationType) QuantizationType)
from ..utils import models_path_prefix from ..utils import models_path_prefix
from vllm.utils import is_hip
@pytest.mark.parametrize("model_args", [ @pytest.mark.parametrize("model_args", [
...@@ -91,6 +92,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args): ...@@ -91,6 +92,8 @@ def test_compressed_tensors_w8a8_dynanmic_per_token(vllm_runner, model_args):
assert output assert output
@pytest.mark.skipif(is_hip(),
reason="WNA16 is not supported on ROCm.")
@pytest.mark.parametrize( @pytest.mark.parametrize(
"wNa16_args", "wNa16_args",
[(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8), [(os.path.join(models_path_prefix,"nm-testing/tinyllama-oneshot-w4a16-channel-v2"), "channel", None, 8),
...@@ -117,6 +120,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args): ...@@ -117,6 +120,8 @@ def test_compressed_tensors_wNa16(vllm_runner, wNa16_args):
assert output assert output
@pytest.mark.skipif(is_hip(),
reason="W4A16 MARLIN is not supported on ROCm.")
def test_compressed_tensors_w4a16_marlin24(vllm_runner): def test_compressed_tensors_w4a16_marlin24(vllm_runner):
model_path = os.path.join(models_path_prefix,"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t") model_path = os.path.join(models_path_prefix,"nm-testing/llama7b-one-shot-2_4-w4a16-marlin24-t")
with vllm_runner(model_path) as llm: with vllm_runner(model_path) as llm:
...@@ -133,6 +138,8 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner): ...@@ -133,6 +138,8 @@ def test_compressed_tensors_w4a16_marlin24(vllm_runner):
assert output assert output
@pytest.mark.skipif(is_hip(),
reason="FP8 is not supported on ROCm.")
def test_compressed_tensors_fp8(vllm_runner): def test_compressed_tensors_fp8(vllm_runner):
model_path = os.path.join(models_path_prefix,"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test") model_path = os.path.join(models_path_prefix,"nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test")
with vllm_runner(model_path) as llm: with vllm_runner(model_path) as llm:
...@@ -158,6 +165,8 @@ def test_compressed_tensors_fp8(vllm_runner): ...@@ -158,6 +165,8 @@ def test_compressed_tensors_fp8(vllm_runner):
assert output assert output
@pytest.mark.skipif(is_hip(),
reason="FP8 KV cache is not supported on ROCm.")
def test_compressed_tensors_kv_cache(vllm_runner): def test_compressed_tensors_kv_cache(vllm_runner):
model_path = os.path.join(models_path_prefix,"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme") model_path = os.path.join(models_path_prefix,"nm-testing/TinyLlama-1.1B-compressed-tensors-kv-cache-scheme")
with vllm_runner(model_path, kv_cache_dtype="fp8") as llm: with vllm_runner(model_path, kv_cache_dtype="fp8") as llm:
......
...@@ -24,31 +24,31 @@ MODEL_ARG_EXPTYPES = [ ...@@ -24,31 +24,31 @@ MODEL_ARG_EXPTYPES = [
# AUTOGPTQ # AUTOGPTQ
# compat: autogptq <=0.7.1 is_marlin_format: bool # compat: autogptq <=0.7.1 is_marlin_format: bool
# Model Serialized in Marlin Format should always use Marlin kernel. # Model Serialized in Marlin Format should always use Marlin kernel.
(os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), None, "marlin"), # (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), None, "marlin"),
(os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "marlin", "marlin"), # (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "marlin", "marlin"),
(os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "gptq", "marlin"), # (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "gptq", "marlin"),
(os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "awq", "ERROR"), (os.path.join(models_path_prefix, "neuralmagic/TinyLlama-1.1B-Chat-v1.0-marlin"), "awq", "ERROR"),
# Model Serialized in Exllama Format. # Model Serialized in Exllama Format.
(os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), None, "gptq_marlin"), # (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), None, "gptq_marlin"),
(os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "marlin", "gptq_marlin"), # (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "marlin", "gptq_marlin"),
(os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "gptq", "gptq"), # (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "gptq", "gptq"),
(os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "awq", "ERROR"), (os.path.join(models_path_prefix, "TheBloke/Llama-2-7B-Chat-GPTQ"), "awq", "ERROR"),
# compat: autogptq >=0.8.0 use checkpoint_format: str # compat: autogptq >=0.8.0 use checkpoint_format: str
# Model Serialized in Marlin Format should always use Marlin kernel. # Model Serialized in Marlin Format should always use Marlin kernel.
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), None, "marlin"), # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), None, "marlin"),
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "marlin", "marlin"), # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "marlin", "marlin"),
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "gptq", "marlin"), # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "gptq", "marlin"),
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "awq", "ERROR"), (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-Marlin-4bit"), "awq", "ERROR"),
# Model Serialized in Exllama Format. # Model Serialized in Exllama Format.
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), None, "gptq_marlin"), # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), None, "gptq_marlin"),
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "marlin", "gptq_marlin"), # (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "marlin", "gptq_marlin"),
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "gptq", "gptq"), (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "gptq", "gptq"),
(os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "awq", "ERROR"), (os.path.join(models_path_prefix, "LnL-AI/TinyLlama-1.1B-Chat-v1.0-GPTQ-4bit"), "awq", "ERROR"),
# AUTOAWQ # AUTOAWQ
(os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), None, "awq_marlin"), # (os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), None, "awq_marlin"),
(os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "awq", "awq"), (os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "awq", "awq"),
(os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "marlin", "awq_marlin"), # (os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "marlin", "awq_marlin"),
(os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "gptq", "ERROR"), (os.path.join(models_path_prefix, "TheBloke/OpenHermes-2.5-Mistral-7B-AWQ"), "gptq", "ERROR"),
] ]
......
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 6,
"num_warps": 8,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 6,
"num_warps": 4,
"num_stages": 1
},
"6144": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 6,
"num_warps": 8,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 1
},
"12288": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 1
},
"32786": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 6,
"num_warps": 4,
"num_stages": 1
}
}
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 32,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 8,
"num_stages": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"32": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1
},
"6144": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1
},
"8192": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 4,
"num_warps": 4,
"num_stages": 1
},
"12288": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 6,
"num_warps": 4,
"num_stages": 1
},
"16384": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1
},
"32786": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 1
}
}
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment