Commit 675ba75f authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.8.3' into v0.8.3-ori

parents 5cc98918 296c6572
...@@ -51,6 +51,7 @@ class ExecutorBase(ABC): ...@@ -51,6 +51,7 @@ class ExecutorBase(ABC):
self.observability_config = vllm_config.observability_config self.observability_config = vllm_config.observability_config
self._init_executor() self._init_executor()
self.is_sleeping = False self.is_sleeping = False
self.sleeping_tags: set[str] = set()
@abstractmethod @abstractmethod
def _init_executor(self) -> None: def _init_executor(self) -> None:
...@@ -204,20 +205,34 @@ class ExecutorBase(ABC): ...@@ -204,20 +205,34 @@ class ExecutorBase(ABC):
time_before_sleep = time.perf_counter() time_before_sleep = time.perf_counter()
self.collective_rpc("sleep", kwargs=dict(level=level)) self.collective_rpc("sleep", kwargs=dict(level=level))
time_after_sleep = time.perf_counter() time_after_sleep = time.perf_counter()
self.sleeping_tags = {"weights", "kv_cache"}
self.is_sleeping = True self.is_sleeping = True
logger.info("It took %.6f seconds to fall asleep.", logger.info("It took %.6f seconds to fall asleep.",
time_after_sleep - time_before_sleep) time_after_sleep - time_before_sleep)
def wake_up(self): def wake_up(self, tags: Optional[list[str]] = None):
if not self.is_sleeping: if not self.is_sleeping:
logger.warning("Executor is not sleeping.") logger.warning("Executor is not sleeping.")
return return
if tags:
for tag in tags:
if tag not in self.sleeping_tags:
logger.warning("Tag %s is not in sleeping tags %s", tag,
self.sleeping_tags)
return
time_before_wakeup = time.perf_counter() time_before_wakeup = time.perf_counter()
self.collective_rpc("wake_up") self.collective_rpc("wake_up", kwargs=dict(tags=tags))
time_after_wakeup = time.perf_counter() time_after_wakeup = time.perf_counter()
self.is_sleeping = False logger.info("It took %.6f seconds to wake up tags %s.",
logger.info("It took %.6f seconds to wake up.", time_after_wakeup - time_before_wakeup,
time_after_wakeup - time_before_wakeup) tags if tags is not None else self.sleeping_tags)
if tags:
for tag in tags:
self.sleeping_tags.remove(tag)
else:
self.sleeping_tags.clear()
if not self.sleeping_tags:
self.is_sleeping = False
def save_sharded_state( def save_sharded_state(
self, self,
......
...@@ -79,7 +79,7 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -79,7 +79,7 @@ class RayDistributedExecutor(DistributedExecutorBase):
# For TPU, avoid compiling NVIDIA's NCCL # For TPU, avoid compiling NVIDIA's NCCL
if current_platform.is_tpu(): if current_platform.is_tpu():
os.environ["VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL"] = "0" os.environ["VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE"] = "shm"
# If the env var is set, it uses the Ray's compiled DAG API # If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead. # which optimizes the control plane overhead.
...@@ -546,10 +546,11 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -546,10 +546,11 @@ class RayDistributedExecutor(DistributedExecutorBase):
"Run `pip install ray[cgraph]` to install it.") "Run `pip install ray[cgraph]` to install it.")
cupy_spec = importlib.util.find_spec("cupy") cupy_spec = importlib.util.find_spec("cupy")
if cupy_spec is None and envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL: if (cupy_spec is None
and envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE == "nccl"):
raise ValueError( raise ValueError(
"cupy is not installed but required since " "cupy is not installed but required since "
"VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL is set. " "VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE is set to 'nccl'. "
"Run `pip install ray[cgraph]` and check cupy installation.") "Run `pip install ray[cgraph]` and check cupy installation.")
def _compiled_ray_dag(self, enable_asyncio: bool): def _compiled_ray_dag(self, enable_asyncio: bool):
...@@ -557,10 +558,17 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -557,10 +558,17 @@ class RayDistributedExecutor(DistributedExecutorBase):
self._check_ray_cgraph_installation() self._check_ray_cgraph_installation()
from ray.dag import InputNode, MultiOutputNode from ray.dag import InputNode, MultiOutputNode
logger.info("VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL = %s", logger.info("VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL) envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE)
logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s", logger.info("VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM = %s",
envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM) envs.VLLM_USE_RAY_COMPILED_DAG_OVERLAP_COMM)
channel_type = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if channel_type not in ("auto", "nccl", "shm"):
raise ValueError(
"Invalid value for VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE: "
f"{channel_type}. Valid values are: 'auto', 'nccl', or 'shm'.")
# Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds # Enlarge the default value of "RAY_CGRAPH_get_timeout" to 300 seconds
# (it is 10 seconds by default). This is a Ray environment variable to # (it is 10 seconds by default). This is a Ray environment variable to
# control the timeout of getting result from a compiled graph execution, # control the timeout of getting result from a compiled graph execution,
...@@ -605,13 +613,12 @@ class RayDistributedExecutor(DistributedExecutorBase): ...@@ -605,13 +613,12 @@ class RayDistributedExecutor(DistributedExecutorBase):
] ]
last_pp_rank = len(self.pp_tp_workers) - 1 last_pp_rank = len(self.pp_tp_workers) - 1
if pp_rank < last_pp_rank: if (pp_rank < last_pp_rank and
envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE != "shm"):
# Specify how intermediate tensors should be passed # Specify how intermediate tensors should be passed
# between pp stages, no need to specify for the last # between pp stages, no need to specify for the last
# pp stage. # pp stage or when using shared memory (the default).
transport = "nccl" \ transport = envs.VLLM_USE_RAY_COMPILED_DAG_CHANNEL_TYPE
if envs.VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL \
else "auto"
outputs = [ outputs = [
output.with_tensor_transport(transport=transport) output.with_tensor_transport(transport=transport)
for output in outputs for output in outputs
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
from collections.abc import Sequence from collections.abc import Sequence
from typing import Literal, TypedDict, Union, cast, overload from typing import Literal, Optional, TypedDict, Union, cast, overload
from typing_extensions import TypeIs from typing_extensions import TypeIs
from vllm.utils import is_list_of from vllm.utils import is_list_of
from .data import (EncoderDecoderInputs, ExplicitEncoderDecoderPrompt, from .data import (ExplicitEncoderDecoderPrompt, ProcessorInputs, PromptType,
ProcessorInputs, PromptType, SingletonPrompt, TextPrompt, SingletonInputs, SingletonPrompt, TextPrompt, TokensPrompt)
TokensPrompt)
class ParsedText(TypedDict): class ParsedText(TypedDict):
...@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt( ...@@ -110,6 +108,14 @@ def is_explicit_encoder_decoder_prompt(
return isinstance(prompt, dict) and "encoder_prompt" in prompt return isinstance(prompt, dict) and "encoder_prompt" in prompt
def is_encoder_decoder_inputs( def split_enc_dec_inputs(
inputs: ProcessorInputs) -> TypeIs[EncoderDecoderInputs]: inputs: ProcessorInputs,
return "encoder" in inputs and "decoder" in inputs ) -> tuple[Optional[SingletonInputs], SingletonInputs]:
if "encoder" in inputs and "decoder" in inputs:
# NOTE: This passes pyright but not mypy
return (
inputs["encoder"], # type: ignore[typeddict-item]
inputs["decoder"], # type: ignore[typeddict-item]
)
return None, inputs
...@@ -261,13 +261,13 @@ class InputPreprocessor: ...@@ -261,13 +261,13 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal # initialized without a tokenizer while using also multi-modal
# input. # input.
if not self.tokenizer: if not self.tokenizer:
tokenizer = None tokenizer = object() # Dummy
else: else:
tokenizer_group = self.get_tokenizer_group() tokenizer_group = self.get_tokenizer_group()
tokenizer = tokenizer_group.get_lora_tokenizer(lora_request) tokenizer = tokenizer_group.get_lora_tokenizer(lora_request)
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(self.model_config,
self.model_config, tokenizer) tokenizer=tokenizer)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -288,14 +288,14 @@ class InputPreprocessor: ...@@ -288,14 +288,14 @@ class InputPreprocessor:
# initialized without a tokenizer while using also multi-modal # initialized without a tokenizer while using also multi-modal
# input. # input.
if not self.tokenizer: if not self.tokenizer:
tokenizer = None tokenizer = object() # Dummy
else: else:
tokenizer_group = self.get_tokenizer_group() tokenizer_group = self.get_tokenizer_group()
tokenizer = await tokenizer_group.get_lora_tokenizer_async( tokenizer = await tokenizer_group.get_lora_tokenizer_async(
lora_request) lora_request)
mm_processor = self.mm_registry.create_processor( mm_processor = self.mm_registry.create_processor(self.model_config,
self.model_config, tokenizer) tokenizer=tokenizer)
if mm_processor_kwargs is None: if mm_processor_kwargs is None:
mm_processor_kwargs = {} mm_processor_kwargs = {}
...@@ -528,6 +528,7 @@ class InputPreprocessor: ...@@ -528,6 +528,7 @@ class InputPreprocessor:
prompt_token_ids=decoder_inputs_to_override[ prompt_token_ids=decoder_inputs_to_override[
"prompt_token_ids"], "prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"], mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=inputs["mm_placeholders"],
) )
else: else:
...@@ -536,6 +537,7 @@ class InputPreprocessor: ...@@ -536,6 +537,7 @@ class InputPreprocessor:
prompt=inputs["prompt"], prompt=inputs["prompt"],
prompt_token_ids=inputs["prompt_token_ids"], prompt_token_ids=inputs["prompt_token_ids"],
mm_kwargs=inputs["mm_kwargs"], mm_kwargs=inputs["mm_kwargs"],
mm_hashes=inputs["mm_hashes"],
mm_placeholders=inputs["mm_placeholders"], mm_placeholders=inputs["mm_placeholders"],
) )
elif inputs["type"] == "token": elif inputs["type"] == "token":
......
...@@ -13,13 +13,12 @@ from typing_extensions import TypeVar, assert_never ...@@ -13,13 +13,12 @@ from typing_extensions import TypeVar, assert_never
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.transformers_utils.processor import cached_processor_from_config from vllm.transformers_utils.processor import cached_processor_from_config
from vllm.transformers_utils.tokenizer import (AnyTokenizer, from vllm.transformers_utils.tokenizer import AnyTokenizer
cached_tokenizer_from_config)
from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides, from vllm.utils import (ClassRegistry, get_allowed_kwarg_only_overrides,
resolve_mm_processor_kwargs) resolve_mm_processor_kwargs)
from .data import ProcessorInputs, SingletonInputs from .data import ProcessorInputs, SingletonInputs
from .parse import is_encoder_decoder_inputs from .parse import split_enc_dec_inputs
if TYPE_CHECKING: if TYPE_CHECKING:
from vllm.config import ModelConfig from vllm.config import ModelConfig
...@@ -329,17 +328,27 @@ class InputRegistry: ...@@ -329,17 +328,27 @@ class InputRegistry:
from vllm.model_executor.model_loader import get_model_architecture from vllm.model_executor.model_loader import get_model_architecture
from vllm.multimodal import MultiModalKwargs from vllm.multimodal import MultiModalKwargs
from vllm.multimodal.profiling import MultiModalProfiler from vllm.multimodal.profiling import MultiModalProfiler
from vllm.sequence import SequenceData
if mm_registry.has_processor(model_config): if mm_registry.has_processor(model_config):
tokenizer = cached_tokenizer_from_config(model_config)
processor = mm_registry.create_processor(model_config, processor = mm_registry.create_processor(model_config,
tokenizer,
disable_cache=True) disable_cache=True)
profiler = MultiModalProfiler(processor) profiler = MultiModalProfiler(processor)
dummy_data_factory = (profiler.get_encoder_dummy_data
if is_encoder_data else dummy_data_v1 = (profiler.get_encoder_dummy_data(seq_len)
profiler.get_decoder_dummy_data) if is_encoder_data else
dummy_data = dummy_data_factory(seq_len) profiler.get_decoder_dummy_data(seq_len))
_seq_data = SequenceData.from_seqs(
dummy_data_v1.prompt_token_ids) # type: ignore[attr-defined]
dummy_data = DummyData(
seq_data=_seq_data,
multi_modal_data=getattr(dummy_data_v1, "multi_modal_data",
None),
multi_modal_placeholders=getattr(dummy_data_v1,
"multi_modal_placeholders",
None),
)
else: else:
model_cls, _ = get_model_architecture(model_config) model_cls, _ = get_model_architecture(model_config)
if is_encoder_data: if is_encoder_data:
...@@ -462,13 +471,11 @@ class InputRegistry: ...@@ -462,13 +471,11 @@ class InputRegistry:
**mm_processor_kwargs, **mm_processor_kwargs,
) )
if is_encoder_decoder_inputs(processed_inputs): encoder_inputs, decoder_inputs = split_enc_dec_inputs(processed_inputs)
self._ensure_mm_kwargs(processed_inputs["encoder"], if encoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(encoder_inputs, mm_processor_kwargs)
self._ensure_mm_kwargs(processed_inputs["decoder"], if decoder_inputs is not None:
mm_processor_kwargs) self._ensure_mm_kwargs(decoder_inputs, mm_processor_kwargs)
else:
self._ensure_mm_kwargs(processed_inputs, mm_processor_kwargs)
return processed_inputs return processed_inputs
......
...@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel): ...@@ -272,7 +272,9 @@ class LoRAModel(AdapterModel):
f" target modules in {expected_lora_modules}" f" target modules in {expected_lora_modules}"
f" but received {unexpected_modules}." f" but received {unexpected_modules}."
f" Please verify that the loaded LoRA module is correct") f" Please verify that the loaded LoRA module is correct")
tensors = torch.load(lora_bin_file_path, map_location=device) tensors = torch.load(lora_bin_file_path,
map_location=device,
weights_only=True)
else: else:
raise ValueError(f"{lora_dir} doesn't contain tensors") raise ValueError(f"{lora_dir} doesn't contain tensors")
......
...@@ -130,7 +130,7 @@ def do_expand_kernel( ...@@ -130,7 +130,7 @@ def do_expand_kernel(
# Identify A and B block pointers # Identify A and B block pointers
offset_k = tl.arange(0, BLOCK_K) offset_k = tl.arange(0, BLOCK_K)
a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride + a_ptr = (cur_input_ptr + ram[:, None] * input_d1_stride +
offset_k[None, :] * input_d2_stride, ) offset_k[None, :] * input_d2_stride)
b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index + b_ptr = (cur_lora_ptr + cur_lora_d0_stride * lora_index +
offset_k[:, None] * cur_lora_d2_stride + offset_k[:, None] * cur_lora_d2_stride +
rbn[None, :] * cur_lora_d1_stride) rbn[None, :] * cur_lora_d1_stride)
......
...@@ -136,6 +136,7 @@ def _lora_expand( ...@@ -136,6 +136,7 @@ def _lora_expand(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
...@@ -157,11 +158,19 @@ def _lora_expand( ...@@ -157,11 +158,19 @@ def _lora_expand(
identifies the the region in token_indices_sorted_by_lora_ids that identifies the the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process. LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process. lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
offset_start (int, optional): Offset start for output_tensor. offset_start (int, optional): Offset start for output_tensor.
Defaults to 0. Defaults to 0.
add_inputs (bool, optional): Whether to add the input tensor to the add_inputs (bool, optional): Whether to add the input tensor to the
output tensor. Defaults to False. output tensor. Defaults to False.
""" """
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32] assert inputs.dtype in [torch.float16, torch.bfloat16, torch.float32]
for weight in lora_b_weights: for weight in lora_b_weights:
assert weight.dtype in [torch.float16, torch.bfloat16] assert weight.dtype in [torch.float16, torch.bfloat16]
...@@ -170,6 +179,8 @@ def _lora_expand( ...@@ -170,6 +179,8 @@ def _lora_expand(
assert output_tensor.is_contiguous() assert output_tensor.is_contiguous()
# metadata sanity check. # metadata sanity check.
M = inputs.size(1)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0) 0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0)
...@@ -181,7 +192,6 @@ def _lora_expand( ...@@ -181,7 +192,6 @@ def _lora_expand(
inputs.device) inputs.device)
K = lora_b_weights[0].shape[-1] # K= rank K = lora_b_weights[0].shape[-1] # K= rank
M = inputs.size(1)
ADD_INPUTS = add_inputs ADD_INPUTS = add_inputs
MAX_LORAS = lora_ids.size(0) MAX_LORAS = lora_ids.size(0)
CAST_TYPE = False CAST_TYPE = False
...@@ -263,6 +273,7 @@ def _lora_expand_fake( ...@@ -263,6 +273,7 @@ def _lora_expand_fake(
num_tokens_per_lora: torch.Tensor, num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
offset_start: int = 0, offset_start: int = 0,
add_inputs: bool = False, add_inputs: bool = False,
) -> None: ) -> None:
......
...@@ -17,6 +17,17 @@ class LoRAKernelMeta: ...@@ -17,6 +17,17 @@ class LoRAKernelMeta:
num_tokens_per_lora: torch.Tensor num_tokens_per_lora: torch.Tensor
lora_token_start_loc: torch.Tensor lora_token_start_loc: torch.Tensor
# The V1 architecture uses the traced torch.compile graphs to execute
# a forward pass. Things to note about this process,
# 1. The tracing infers all python scalar datatype objects into a constant
# value.
# 2. The tracing cannot handle dynamic control flow. (dynamic control flow
# is an experimental feature in pytorch)
# 3. The internals of torch.ops functions are not traced.
# We disguise the "no_lora" flag as a cpu tensor and leverage point number 3
# to early exit from inside the lora_expand / lora_shrink torch operation.
no_lora_flag_cpu: torch.Tensor
@staticmethod @staticmethod
def make(max_loras: int, max_num_tokens: int, def make(max_loras: int, max_num_tokens: int,
device: Union[torch.device, str]) -> "LoRAKernelMeta": device: Union[torch.device, str]) -> "LoRAKernelMeta":
...@@ -47,17 +58,24 @@ class LoRAKernelMeta: ...@@ -47,17 +58,24 @@ class LoRAKernelMeta:
lora_token_start_loc = torch.zeros(max_loras + 2, lora_token_start_loc = torch.zeros(max_loras + 2,
dtype=torch.int32, dtype=torch.int32,
device=device) device=device)
no_lora_flag_cpu = torch.tensor([False],
dtype=torch.bool,
device='cpu')
return LoRAKernelMeta( return LoRAKernelMeta(
token_lora_mapping=token_lora_mapping, token_lora_mapping=token_lora_mapping,
token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids, token_indices_sorted_by_lora_ids=token_indices_sorted_by_lora_ids,
active_lora_ids=active_lora_ids, active_lora_ids=active_lora_ids,
num_tokens_per_lora=num_tokens_per_lora, num_tokens_per_lora=num_tokens_per_lora,
lora_token_start_loc=lora_token_start_loc) lora_token_start_loc=lora_token_start_loc,
no_lora_flag_cpu=no_lora_flag_cpu)
def _reset(self): def _reset(self):
self.active_lora_ids.fill_(-1) self.active_lora_ids.fill_(-1)
self.num_tokens_per_lora.fill_(0) self.num_tokens_per_lora.fill_(0)
self.lora_token_start_loc.fill_(0) self.lora_token_start_loc.fill_(0)
self.no_lora_flag_cpu.fill_(False)
def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None: def prepare_tensors(self, token_lora_mapping: torch.Tensor) -> None:
""" """
...@@ -70,6 +88,14 @@ class LoRAKernelMeta: ...@@ -70,6 +88,14 @@ class LoRAKernelMeta:
self._reset() self._reset()
# Check and record no-lora case.
no_lora = torch.all(token_lora_mapping == -1)
self.no_lora_flag_cpu[0] = no_lora
if no_lora:
# Early exit. LoRA kernels will not be run.
return
num_tokens = token_lora_mapping.size(0) num_tokens = token_lora_mapping.size(0)
# copy token lora mapping # copy token lora mapping
...@@ -100,7 +126,7 @@ class LoRAKernelMeta: ...@@ -100,7 +126,7 @@ class LoRAKernelMeta:
def meta_args( def meta_args(
self, token_nums: int self, token_nums: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
torch.Tensor]: torch.Tensor, torch.Tensor]:
""" """
This function returns the kernel metadata required for the current This function returns the kernel metadata required for the current
forward pass execution of the kernel. The function returns all the forward pass execution of the kernel. The function returns all the
...@@ -111,7 +137,11 @@ class LoRAKernelMeta: ...@@ -111,7 +137,11 @@ class LoRAKernelMeta:
token_nums (int): Number of input tokens in the current forward token_nums (int): Number of input tokens in the current forward
pass. pass.
""" """
return (self.token_lora_mapping[:token_nums], return (
self.token_indices_sorted_by_lora_ids[:token_nums], self.token_lora_mapping[:token_nums],
self.num_tokens_per_lora, self.lora_token_start_loc, self.token_indices_sorted_by_lora_ids[:token_nums],
self.active_lora_ids) self.num_tokens_per_lora,
self.lora_token_start_loc,
self.active_lora_ids,
self.no_lora_flag_cpu,
)
...@@ -106,6 +106,7 @@ def _lora_shrink( ...@@ -106,6 +106,7 @@ def _lora_shrink(
num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1] num_tokens_per_lora: torch.Tensor, # shape [max-loras + 1]
lora_token_start_loc: torch.Tensor, # shape [max-loras + 2] lora_token_start_loc: torch.Tensor, # shape [max-loras + 2]
lora_ids: torch.Tensor, # shape [max-loras + 1] lora_ids: torch.Tensor, # shape [max-loras + 1]
no_lora_flag_cpu: torch.Tensor, # shape [1]
scaling: float, scaling: float,
) -> None: ) -> None:
""" """
...@@ -126,8 +127,16 @@ def _lora_shrink( ...@@ -126,8 +127,16 @@ def _lora_shrink(
identifies the region in token_indices_sorted_by_lora_ids that identifies the region in token_indices_sorted_by_lora_ids that
LoRA lora_ids[i] should process. LoRA lora_ids[i] should process.
lora_ids (torch.Tensor): LoRA ids to process. lora_ids (torch.Tensor): LoRA ids to process.
no_lora_flag_cpu (torch.Tensor): A CPU tensor of size 1, that indicates
if there are any requests that require LoRA.
scaling (float): Scaling factor. scaling (float): Scaling factor.
""" """
assert no_lora_flag_cpu.numel() == 1
if no_lora_flag_cpu.item():
# None of the inputs require LoRA.
return
assert inputs.dtype == lora_a_weights[0].dtype assert inputs.dtype == lora_a_weights[0].dtype
assert inputs.dtype in [torch.float16, torch.bfloat16] assert inputs.dtype in [torch.float16, torch.bfloat16]
for weight in lora_a_weights: for weight in lora_a_weights:
...@@ -138,6 +147,8 @@ def _lora_shrink( ...@@ -138,6 +147,8 @@ def _lora_shrink(
assert output_tensor.is_contiguous() assert output_tensor.is_contiguous()
# metadata sanity check # metadata sanity check
M = inputs.size(0)
assert token_lora_mapping.size(0) == M
assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size( assert token_lora_mapping.size(0) == token_indices_sorted_by_lora_ids.size(
0) 0)
assert lora_ids.size(0) == num_tokens_per_lora.size(0) assert lora_ids.size(0) == num_tokens_per_lora.size(0)
...@@ -146,7 +157,6 @@ def _lora_shrink( ...@@ -146,7 +157,6 @@ def _lora_shrink(
(lora_ptr_tensor, lora_strides_d0, lora_strides_d1, (lora_ptr_tensor, lora_strides_d0, lora_strides_d1,
lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device) lora_strides_d2) = _get_lora_a_ptr(lora_a_weights, inputs.device)
N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank N, K = lora_a_weights[0].shape[-2:] # K=hidden_size,N=rank
M = inputs.size(0)
NUM_SLICES = len(lora_a_weights) NUM_SLICES = len(lora_a_weights)
MAX_LORAS = lora_ids.size(0) MAX_LORAS = lora_ids.size(0)
...@@ -218,6 +228,7 @@ def _lora_shrink_fake( ...@@ -218,6 +228,7 @@ def _lora_shrink_fake(
num_tokens_per_lora: torch.Tensor, num_tokens_per_lora: torch.Tensor,
lora_token_start_loc: torch.Tensor, lora_token_start_loc: torch.Tensor,
lora_ids: torch.Tensor, lora_ids: torch.Tensor,
no_lora_flag_cpu: torch.Tensor,
scaling: float, scaling: float,
) -> None: ) -> None:
return return
......
...@@ -5,10 +5,10 @@ from __future__ import annotations ...@@ -5,10 +5,10 @@ from __future__ import annotations
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import get_reasoner
from vllm.model_executor.guided_decoding.utils import ( from vllm.model_executor.guided_decoding.utils import (
convert_lark_to_gbnf, grammar_is_likely_lark, convert_lark_to_gbnf, grammar_is_likely_lark,
has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features) has_lmf_unsupported_json_features, has_xgrammar_unsupported_json_features)
from vllm.reasoning import ReasoningParserManager
if TYPE_CHECKING: if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
...@@ -79,12 +79,6 @@ def maybe_backend_fallback( ...@@ -79,12 +79,6 @@ def maybe_backend_fallback(
"xgrammar does not support Lark grammars and the " "xgrammar does not support Lark grammars and the "
"grammar failed to convert to GBNF.", "outlines") "grammar failed to convert to GBNF.", "outlines")
elif guided_params.json_object:
# https://github.com/mlc-ai/xgrammar/issues/256
fallback_or_error(guided_params,
"xgrammar does not support json_object.",
"guidance")
# If the xgrammar module cannot be imported successfully, # If the xgrammar module cannot be imported successfully,
# we should still allow users to use guided decoding with a fallback. # we should still allow users to use guided decoding with a fallback.
elif not xgr_installed: elif not xgr_installed:
...@@ -107,7 +101,11 @@ async def get_guided_decoding_logits_processor( ...@@ -107,7 +101,11 @@ async def get_guided_decoding_logits_processor(
model_config: ModelConfig, model_config: ModelConfig,
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
reasoner = get_reasoner(tokenizer, reasoning_backend) reasoner = None
if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
...@@ -146,8 +144,11 @@ def get_local_guided_decoding_logits_processor( ...@@ -146,8 +144,11 @@ def get_local_guided_decoding_logits_processor(
reasoning_backend: str | None = None) -> LogitsProcessor | None: reasoning_backend: str | None = None) -> LogitsProcessor | None:
guided_params = maybe_backend_fallback(guided_params) guided_params = maybe_backend_fallback(guided_params)
# Get the reasoner if needed, it will be None if reasoning_ reasoner = None
reasoner = get_reasoner(tokenizer, reasoning_backend) if reasoning_backend is not None:
reasoner_class = ReasoningParserManager.get_reasoning_parser(
reasoning_backend)
reasoner = reasoner_class(tokenizer)
# CFG grammar not supported by LMFE, so we use outlines instead # CFG grammar not supported by LMFE, so we use outlines instead
if guided_params.backend_name == 'outlines': if guided_params.backend_name == 'outlines':
......
...@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor( ...@@ -18,14 +18,22 @@ def get_local_guidance_guided_decoding_logits_processor(
""" """
grm = "" grm = ""
any_whitespace = 'disable-any-whitespace' not in \
guided_params.backend_options()
if guided_params.json: if guided_params.json:
grm = llguidance.LLMatcher.grammar_from_json_schema( grm = llguidance.LLMatcher.grammar_from_json_schema(
guided_params.json, guided_params.json,
overrides={"whitespace_pattern": guided_params.whitespace_pattern}) overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.json_object: elif guided_params.json_object:
grm = llguidance.LLMatcher.grammar_from_json_schema( grm = llguidance.LLMatcher.grammar_from_json_schema(
'{"type": "object"}', '{"type": "object"}',
overrides={"whitespace_pattern": guided_params.whitespace_pattern}) overrides={"whitespace_pattern": guided_params.whitespace_pattern},
defaults={
"whitespace_flexible": any_whitespace,
})
elif guided_params.regex: elif guided_params.regex:
grm = llguidance.grammar_from("regex", guided_params.regex) grm = llguidance.grammar_from("regex", guided_params.regex)
elif guided_params.choice: elif guided_params.choice:
......
...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase ...@@ -12,7 +12,7 @@ from transformers import PreTrainedTokenizerBase
from vllm.model_executor.guided_decoding.outlines_logits_processors import ( from vllm.model_executor.guided_decoding.outlines_logits_processors import (
CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor) CFGLogitsProcessor, JSONLogitsProcessor, RegexLogitsProcessor)
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16 ...@@ -61,7 +61,7 @@ _MAX_THREADPOOL_WORKERS = 16
async def get_outlines_guided_decoding_logits_processor( async def get_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor( ...@@ -92,7 +92,7 @@ async def get_outlines_guided_decoding_logits_processor(
def get_local_outlines_guided_decoding_logits_processor( def get_local_outlines_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor, ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor,
None]: None]:
""" """
...@@ -141,7 +141,7 @@ def _get_logits_processor( ...@@ -141,7 +141,7 @@ def _get_logits_processor(
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
mode: GuidedDecodingMode, mode: GuidedDecodingMode,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]: ) -> Union[JSONLogitsProcessor, RegexLogitsProcessor, CFGLogitsProcessor]:
if mode == GuidedDecodingMode.JSON: if mode == GuidedDecodingMode.JSON:
return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern, return JSONLogitsProcessor(guide, tokenizer, whitespace_pattern,
......
...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase ...@@ -34,8 +34,8 @@ from transformers import PreTrainedTokenizerBase
import vllm.envs as envs import vllm.envs as envs
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.guided_decoding.reasoner import Reasoner
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.reasoning import ReasoningParser
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -49,9 +49,9 @@ else: ...@@ -49,9 +49,9 @@ else:
class BaseLogitsProcessor: class BaseLogitsProcessor:
def __init__(self, guide: Guide, reasoner: Optional[Reasoner]): def __init__(self, guide: Guide, reasoner: Optional[ReasoningParser]):
self._guide: Guide = guide self._guide: Guide = guide
self._reasoner: Optional[Reasoner] = reasoner self._reasoner: Optional[ReasoningParser] = reasoner
# CFGState is used for the FSM state for CFGGuide # CFGState is used for the FSM state for CFGGuide
self._fsm_state: DefaultDict[int, Union[int, self._fsm_state: DefaultDict[int, Union[int,
CFGState]] = defaultdict(int) CFGState]] = defaultdict(int)
...@@ -69,7 +69,7 @@ class BaseLogitsProcessor: ...@@ -69,7 +69,7 @@ class BaseLogitsProcessor:
# Remove the reasoning tokens from the input_ids # Remove the reasoning tokens from the input_ids
# We need this because our implementation relies on the # We need this because our implementation relies on the
# hash of the input_ids to store the FSM state. # hash of the input_ids to store the FSM state.
input_ids = self._reasoner.extract_content(input_ids) input_ids = self._reasoner.extract_content_ids(input_ids)
seq_id = hash(tuple(input_ids)) seq_id = hash(tuple(input_ids))
...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor): ...@@ -142,7 +142,7 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
self, self,
regex_string: str, regex_string: str,
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner], reasoner: Optional[ReasoningParser],
): ):
"""Compile the FSM that drives the regex-structured generation. """Compile the FSM that drives the regex-structured generation.
...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor): ...@@ -163,7 +163,7 @@ class JSONLogitsProcessor(RegexLogitsProcessor):
def __init__(self, schema: Union[str, Dict, BaseModel], def __init__(self, schema: Union[str, Dict, BaseModel],
tokenizer: PreTrainedTokenizerBase, tokenizer: PreTrainedTokenizerBase,
whitespace_pattern: Union[str, None], whitespace_pattern: Union[str, None],
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the JSON-guided generation. """Compile the FSM that drives the JSON-guided generation.
Parameters Parameters
...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor): ...@@ -203,7 +203,7 @@ class CFGLogitsProcessor(BaseLogitsProcessor):
return CFGGuide(cfg, tokenizer) return CFGGuide(cfg, tokenizer)
def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase, def __init__(self, cfg: str, tokenizer: PreTrainedTokenizerBase,
reasoner: Optional[Reasoner]): reasoner: Optional[ReasoningParser]):
"""Compile the FSM that drives the context free grammar generation. """Compile the FSM that drives the context free grammar generation.
Parameters Parameters
......
...@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer, ...@@ -19,6 +19,10 @@ def get_reasoner(tokenizer: PreTrainedTokenizer,
return None return None
elif reasoning_backend == "deepseek_r1": elif reasoning_backend == "deepseek_r1":
return DeepSeekReasoner.from_tokenizer(tokenizer) return DeepSeekReasoner.from_tokenizer(tokenizer)
elif reasoning_backend == "granite":
logger.warning(
"Granite reasoner not yet implemented for structured outputs")
return None
else: else:
# Raise a warning for unknown reasoning backend and return None # Raise a warning for unknown reasoning backend and return None
# We cannot raise an error here because some reasoning models # We cannot raise an error here because some reasoning models
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
from vllm.model_executor.guided_decoding.reasoner.reasoner import Reasoner
@dataclass
class DeepSeekReasoner(Reasoner):
"""
Reasoner for DeepSeek R series models.
"""
start_token_id: int
end_token_id: int
start_token: str = "<think>"
end_token: str = "</think>"
@classmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
return cls(start_token_id=tokenizer.encode(
"<think>", add_special_tokens=False)[0],
end_token_id=tokenizer.encode("</think>",
add_special_tokens=False)[0])
def is_reasoning_end(self, input_ids: list[int]) -> bool:
return self.end_token_id in input_ids
def extract_content(self, input_ids: list[int]) -> list[int]:
"""
Extract the content after the end tokens
"""
if self.end_token_id not in input_ids or \
input_ids.index(self.end_token_id) + 1 == len(input_ids):
return []
else:
return input_ids[input_ids.index(self.end_token_id) + 1:]
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass
from transformers import PreTrainedTokenizer
@dataclass
class Reasoner(ABC):
@abstractmethod
def from_tokenizer(cls, tokenizer: PreTrainedTokenizer) -> Reasoner:
pass
@abstractmethod
def is_reasoning_end(self, input_ids: list[int]) -> bool:
pass
@abstractmethod
def extract_content(self, input_ids: list[int]) -> list[int]:
pass
...@@ -27,7 +27,7 @@ if TYPE_CHECKING: ...@@ -27,7 +27,7 @@ if TYPE_CHECKING:
from transformers import PreTrainedTokenizer from transformers import PreTrainedTokenizer
from vllm.config import ModelConfig from vllm.config import ModelConfig
from vllm.model_executor.guided_decoding.reasoner import Reasoner from vllm.reasoning import ReasoningParser
from vllm.sampling_params import GuidedDecodingParams from vllm.sampling_params import GuidedDecodingParams
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor( ...@@ -37,7 +37,7 @@ def get_local_xgrammar_guided_decoding_logits_processor(
guided_params: GuidedDecodingParams, guided_params: GuidedDecodingParams,
tokenizer: PreTrainedTokenizer, tokenizer: PreTrainedTokenizer,
model_config: ModelConfig, model_config: ModelConfig,
reasoner: Reasoner | None, reasoner: ReasoningParser | None,
max_threads: int = 8): max_threads: int = 8):
config = GrammarConfig.from_guided_params(guided_params=guided_params, config = GrammarConfig.from_guided_params(guided_params=guided_params,
model_config=model_config, model_config=model_config,
...@@ -280,7 +280,7 @@ class GrammarConfig: ...@@ -280,7 +280,7 @@ class GrammarConfig:
class XGrammarLogitsProcessor: class XGrammarLogitsProcessor:
"""Wrapper class to support pickle protocol""" """Wrapper class to support pickle protocol"""
config: GrammarConfig config: GrammarConfig
reasoner: Reasoner | None = None reasoner: ReasoningParser | None = None
ctx: xgr.CompiledGrammar | None = None ctx: xgr.CompiledGrammar | None = None
tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment] tokenizer_info: xgr.TokenizerInfo = None # type: ignore[assignment]
...@@ -320,7 +320,10 @@ class XGrammarLogitsProcessor: ...@@ -320,7 +320,10 @@ class XGrammarLogitsProcessor:
elif self.config.grammar_str is not None: elif self.config.grammar_str is not None:
self.ctx = compiler.compile_grammar(self.config.grammar_str) self.ctx = compiler.compile_grammar(self.config.grammar_str)
elif self.config.json_object: elif self.config.json_object:
self.ctx = compiler.compile_builtin_json_grammar() any_whitespace = self.config.any_whitespace
self.ctx = compiler\
.compile_json_schema('{"type": "object"}',
any_whitespace=any_whitespace)
else: else:
raise ValueError( raise ValueError(
"Invalid configuration for xgrammar logits processor") "Invalid configuration for xgrammar logits processor")
......
...@@ -35,6 +35,8 @@ if HAS_TRITON: ...@@ -35,6 +35,8 @@ if HAS_TRITON:
# import to register the custom ops # import to register the custom ops
import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa import vllm.model_executor.layers.fused_moe.fused_marlin_moe # noqa
import vllm.model_executor.layers.fused_moe.fused_moe # noqa import vllm.model_executor.layers.fused_moe.fused_moe # noqa
from vllm.model_executor.layers.fused_moe.cutlass_moe import (
cutlass_moe_fp8)
from vllm.model_executor.layers.fused_moe.fused_moe import ( from vllm.model_executor.layers.fused_moe.fused_moe import (
fused_experts, fused_moe, fused_topk, get_config_file_name, fused_experts, fused_moe, fused_topk, get_config_file_name,
grouped_topk) grouped_topk)
...@@ -45,4 +47,5 @@ if HAS_TRITON: ...@@ -45,4 +47,5 @@ if HAS_TRITON:
"fused_experts", "fused_experts",
"get_config_file_name", "get_config_file_name",
"grouped_topk", "grouped_topk",
"cutlass_moe_fp8",
] ]
{
"1": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"2": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"4": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 16,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"8": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"16": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"24": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"32": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"48": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 1
},
"64": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 2,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"96": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"128": {
"BLOCK_SIZE_M": 16,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"256": {
"BLOCK_SIZE_M": 32,
"BLOCK_SIZE_N": 32,
"BLOCK_SIZE_K": 256,
"GROUP_SIZE_M": 8,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"512": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1024": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"1536": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 8,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"2048": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 64,
"BLOCK_SIZE_K": 128,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"3072": {
"BLOCK_SIZE_M": 128,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 8,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
},
"4096": {
"BLOCK_SIZE_M": 64,
"BLOCK_SIZE_N": 128,
"BLOCK_SIZE_K": 64,
"GROUP_SIZE_M": 1,
"num_warps": 4,
"num_stages": 2,
"waves_per_eu": 0,
"matrix_instr_nonkdim": 16,
"kpack": 2
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment