Commit 539aa992 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge tag 'v0.6.2' into v0.6.2-dev

parents 93872128 7193774b
......@@ -120,9 +120,8 @@ class Fp8LinearMethod(LinearMethodBase):
# For GPUs that lack FP8 hardware support, we can leverage the Marlin
# kernel for fast weight-only FP8 quantization
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
self.use_marlin = capability < 89 or envs.VLLM_TEST_FORCE_FP8_MARLIN
self.use_marlin = (not current_platform.has_device_capability(89)
or envs.VLLM_TEST_FORCE_FP8_MARLIN)
# Disable marlin for rocm
if is_hip():
self.use_marlin = False
......
......@@ -55,7 +55,10 @@ class GGUFConfig(QuantizationConfig):
def _fuse_mul_mat(x: torch.Tensor, qweight: torch.Tensor,
qweight_type: int) -> torch.Tensor:
# use dequantize mulmat for IQmatrix, mmq for k-quants
if qweight_type >= 16:
if x.shape[0] == 1:
# enable mmvq in contiguous batching
y = ops.ggml_mul_mat_vec_a8(qweight, x, qweight_type, qweight.shape[0])
elif qweight_type >= 16:
block_size, type_size = gguf.GGML_QUANT_SIZES[qweight_type]
shape = (qweight.shape[0], qweight.shape[1] // type_size * block_size)
weight = ops.ggml_dequantize(qweight, qweight_type, *shape)
......
......@@ -217,6 +217,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.qzeros = Parameter(layer.qzeros.data, requires_grad=False)
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.g_idx = Parameter(layer.g_idx.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
# exllama needs to shuffle the weight after the weight is loaded
# here we do the shuffle on first forward pass
......
from typing import Any, Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Set, Union
import torch
from torch.nn import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
......@@ -11,12 +10,12 @@ from vllm.model_executor.layers.linear import (LinearBase, LinearMethodBase,
set_weight_attrs)
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.layers.quantization.kernels import (
MPLinearLayerConfig, choose_mp_linear_kernel)
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
apply_gptq_marlin_linear, check_marlin_supported, marlin_is_k_full,
marlin_make_empty_g_idx, marlin_make_workspace, marlin_moe_permute_scales,
marlin_permute_scales, marlin_repeat_scales_on_all_ranks,
marlin_sort_g_idx, replace_tensor, verify_marlin_supported,
verify_marlin_supports_shape)
check_marlin_supported, marlin_moe_permute_scales,
marlin_repeat_scales_on_all_ranks, verify_marlin_supported)
from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead
from vllm.model_executor.parameter import (ChannelQuantScaleParameter,
GroupQuantScaleParameter,
......@@ -132,10 +131,10 @@ class GPTQMarlinConfig(QuantizationConfig):
def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
# Extract data from quant config.
quant_method = quant_config.get("quant_method", "").lower()
num_bits = quant_config.get("bits", None)
group_size = quant_config.get("group_size", None)
sym = quant_config.get("sym", None)
desc_act = quant_config.get("desc_act", None)
num_bits = quant_config.get("bits")
group_size = quant_config.get("group_size")
sym = quant_config.get("sym")
desc_act = quant_config.get("desc_act")
if quant_method != "gptq":
return False
......@@ -159,6 +158,8 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
quant_config: The GPTQ Marlin quantization config.
"""
_kernel_backends_being_used: Set[str] = set()
def __init__(self, quant_config: GPTQMarlinConfig) -> None:
self.quant_config = quant_config
......@@ -176,25 +177,34 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
del output_size
output_size_per_partition = sum(output_partition_sizes)
is_row_parallel = input_size != input_size_per_partition
weight_loader = extra_weight_attrs.get("weight_loader")
mp_linear_kernel_config = MPLinearLayerConfig(
full_weight_shape=(input_size, output_size),
partition_weight_shape=\
(input_size_per_partition, output_size_per_partition),
weight_type=self.quant_config.quant_type,
act_type=params_dtype,
group_size=self.quant_config.group_size,
zero_points=False,
has_g_idx=self.quant_config.desc_act
)
kernel_type = choose_mp_linear_kernel(mp_linear_kernel_config)
if kernel_type.__name__ not in self._kernel_backends_being_used:
logger.info("Using %s for GPTQMarlinLinearMethod",
kernel_type.__name__)
self._kernel_backends_being_used.add(kernel_type.__name__)
# Normalize group_size
if self.quant_config.group_size != -1:
group_size = self.quant_config.group_size
else:
group_size = input_size
verify_marlin_supports_shape(
output_size_per_partition=output_size_per_partition,
input_size_per_partition=input_size_per_partition,
input_size=input_size,
group_size=group_size,
)
# Determine sharding
if marlin_repeat_scales_on_all_ranks(self.quant_config.desc_act,
self.quant_config.group_size,
......@@ -275,57 +285,15 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
layer.register_parameter("g_idx", g_idx)
layer.register_parameter("scales", scales)
layer.register_parameter("qzeros", qzeros)
layer.input_size_per_partition = input_size_per_partition
layer.output_size_per_partition = output_size_per_partition
layer.input_size = input_size
layer.is_k_full = marlin_is_k_full(self.quant_config.desc_act,
is_row_parallel)
# Checkpoints are serialized in AutoGPTQ format, which is different from the
# marlin format. This function is called after the weights are loaded.
# Here, we handle the repacking, including the activation reordering case.
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = layer.qweight.device
# required by torch.compile
layer.qweight = Parameter(layer.qweight.data, requires_grad=False)
layer.scales = Parameter(layer.scales.data, requires_grad=False)
self.kernel = kernel_type(mp_linear_kernel_config,
w_q_param_name="qweight",
w_s_param_name="scales",
w_zp_param_name="qzeros",
w_gidx_param_name="g_idx")
# Allocate marlin workspace
layer.workspace = marlin_make_workspace(
layer.output_size_per_partition, device)
# Handle sorting for activation reordering if needed.
if self.quant_config.desc_act:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(layer.g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
replace_tensor(layer, "g_idx", g_idx)
else:
layer.g_idx = marlin_make_empty_g_idx(device)
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
# No zero-point
layer.zp = marlin_make_empty_g_idx(device)
# Repack weights from autogptq format to marlin format.
marlin_qweight = ops.gptq_marlin_repack(
layer.qweight,
perm=layer.g_idx_sort_indices,
size_k=layer.input_size_per_partition,
size_n=layer.output_size_per_partition,
num_bits=self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "qweight", marlin_qweight)
# Permute scales from autogptq format to marlin format.
marlin_scales = marlin_permute_scales(
layer.scales,
size_k=(layer.input_size if self.quant_config.desc_act else
layer.input_size_per_partition),
size_n=layer.output_size_per_partition,
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "scales", marlin_scales)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
self.kernel.process_weights_after_loading(layer)
def apply(
self,
......@@ -333,20 +301,7 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return apply_gptq_marlin_linear(
input=x,
weight=layer.qweight,
weight_scale=layer.scales,
weight_zp=layer.zp,
g_idx=layer.g_idx,
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=layer.workspace,
wtype=self.quant_config.quant_type,
output_size_per_partition=layer.output_size_per_partition,
input_size_per_partition=layer.input_size_per_partition,
is_k_full=layer.is_k_full,
bias=bias,
)
return self.kernel.apply_weights(layer, x, bias)
class GPTQMarlinMoEMethod(FusedMoEMethodBase):
......@@ -506,12 +461,12 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
w13_g_idx_sort_indices[e]]
w2_sorted_g_idx[e] = layer.w2_g_idx[e][
w2_g_idx_sort_indices[e]]
replace_tensor(layer, "w13_g_idx", w13_sorted_g_idx)
replace_tensor(layer, "w2_g_idx", w2_sorted_g_idx)
replace_tensor(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_tensor(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
replace_parameter(layer, "w13_g_idx", w13_sorted_g_idx)
replace_parameter(layer, "w2_g_idx", w2_sorted_g_idx)
replace_parameter(layer, "w13_g_idx_sort_indices",
w13_g_idx_sort_indices)
replace_parameter(layer, "w2_g_idx_sort_indices",
w2_g_idx_sort_indices)
else:
# Reset g_idx related tensors
num_experts = layer.w13_g_idx.shape[0]
......@@ -544,7 +499,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w13_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w13_qweight", marlin_w13_qweight)
replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
marlin_w2_qweight = ops.gptq_marlin_moe_repack(
layer.w2_qweight,
layer.w2_g_idx_sort_indices,
......@@ -552,7 +507,7 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
layer.w2_qweight.shape[2],
self.quant_config.quant_type.size_bits,
)
replace_tensor(layer, "w2_qweight", marlin_w2_qweight)
replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
# Repack scales
marlin_w13_scales = marlin_moe_permute_scales(
s=layer.w13_scales,
......@@ -560,14 +515,14 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
size_n=layer.w13_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w13_scales", marlin_w13_scales)
replace_parameter(layer, "w13_scales", marlin_w13_scales)
marlin_w2_scales = marlin_moe_permute_scales(
s=layer.w2_scales,
size_k=layer.w2_scales.shape[1] * self.quant_config.pack_factor,
size_n=layer.w2_scales.shape[2],
group_size=self.quant_config.group_size,
)
replace_tensor(layer, "w2_scales", marlin_w2_scales)
replace_parameter(layer, "w2_scales", marlin_w2_scales)
def apply(
self,
......@@ -611,4 +566,5 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
topk_ids,
w1_scale=layer.w13_scales,
w2_scale=layer.w2_scales,
num_bits=self.quant_config.quant_type.size_bits,
).to(orig_dtype)
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Optional, Tuple
import torch
from vllm.model_executor.layers.quantization.utils import replace_parameter
from vllm.scalar_type import ScalarType
@dataclass
class MPLinearLayerConfig:
full_weight_shape: Tuple[int, int] # [in, out]
partition_weight_shape: Tuple[int, int]
weight_type: ScalarType
act_type: torch.dtype
group_size: int
zero_points: bool
has_g_idx: bool
class MPLinearKernel(ABC):
@classmethod
@abstractmethod
def get_min_capability(cls) -> int:
raise NotImplementedError
@classmethod
@abstractmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
raise NotImplementedError
def __init__(self,
c: MPLinearLayerConfig,
w_q_param_name: str,
w_s_param_name: str,
w_zp_param_name: Optional[str] = None,
w_gidx_param_name: Optional[str] = None) -> None:
assert self.can_implement(c)
self.config = c
self.w_q_name = w_q_param_name
self.w_s_name = w_s_param_name
self.w_zp_name = w_zp_param_name
self.w_gidx_name = w_gidx_param_name
@abstractmethod
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
raise NotImplementedError
@abstractmethod
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
raise NotImplementedError
def _transform_param(self, layer: torch.nn.Module, name: Optional[str],
fn: Callable) -> None:
if name is not None and getattr(layer, name, None) is not None:
old_param = getattr(layer, name)
new_param = fn(old_param)
# replace the parameter with torch.nn.Parameter for TorchDynamo
# compatibility
replace_parameter(
layer, name,
torch.nn.Parameter(new_param.data, requires_grad=False))
def _get_weight_params(
self, layer: torch.nn.Module
) -> Tuple[torch.Tensor, # w_q
torch.Tensor, # w_s
Optional[torch.Tensor], # w_zp,
Optional[torch.Tensor] # w_gidx
]:
return (
getattr(layer, self.w_q_name),
getattr(layer, self.w_s_name),
getattr(layer, self.w_zp_name or "", None),
getattr(layer, self.w_gidx_name or "", None),
)
import os
from typing import List, Optional, Type
from vllm.model_executor.layers.quantization.kernels.machete import (
MacheteLinearKernel)
from vllm.model_executor.layers.quantization.kernels.marlin import (
MarlinLinearKernel)
from vllm.model_executor.layers.quantization.kernels.MPLinearKernel import (
MPLinearKernel, MPLinearLayerConfig)
from vllm.platforms import current_platform
# in priority/performance order (when available)
_POSSIBLE_KERNELS: List[Type[MPLinearKernel]] = [
MacheteLinearKernel,
MarlinLinearKernel,
]
def choose_mp_linear_kernel(
config: MPLinearLayerConfig,
compute_capability: Optional[int] = None) -> Type[MPLinearKernel]:
"""
Choose an MPLinearKernel that can implement the given config for the given
compute capability. Attempts to choose the best kernel in terms of
performance.
Args:
config (MPLinearLayerConfig): Description of the linear layer to be
implemented.
compute_capability (Optional[int], optional): The compute capability of
the target device, if None uses `current_platform` to get the compute
capability. Defaults to None.
Raises:
ValueError: If no kernel can implement the given config.
Returns:
Type[MPLinearKernel]: Chosen kernel.
"""
if compute_capability is None:
if current_platform is None:
raise ValueError("Cannot determine compute capability")
_cc = current_platform.get_device_capability()
compute_capability = _cc[0] * 10 + _cc[1]
failure_reasons = []
for kernel in _POSSIBLE_KERNELS:
if kernel.__name__ in os.environ.get("VLLM_DISABLED_KERNELS", "")\
.split(","):
failure_reasons.append(
f' {kernel.__name__} disabled by environment variable')
continue
if kernel.get_min_capability() > compute_capability:
failure_reasons.append(
f"{kernel.__name__} requires capability "
f"{kernel.get_min_capability()}, current compute capability "
f"is {compute_capability}")
continue
can_implement, failure_reason = kernel.can_implement(config)
if can_implement:
return kernel
else:
failure_reasons.append(
f' {kernel.__name__} cannot implement due to: {failure_reason}'
)
raise ValueError(
"Failed to find a kernel that can implement the "\
"WNA16 linear layer. Reasons: \n"
+ '\n'.join(failure_reasons))
from functools import partial
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.machete_utils import (
MACHETE_SUPPORTED_GROUP_SIZES, check_machete_supports_shape,
query_machete_supported_quant_types)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
pack_weights_into_int32, unpack_weights_into_int32)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MacheteLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 90
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.has_g_idx and\
c.partition_weight_shape[0] != c.full_weight_shape[0]:
return False, "Act reordering currently not supported by Machete, "\
"when the input features are partitioned across "\
"devices"
if c.zero_points:
return False, "Zero points currently not supported by "\
" Compressed Tensors + Machete. (Kernel supports it"\
" but CompressedTensorsWNA16 does not so support has"\
" not been added to MacheteWNA16Kernel yet"
if c.weight_type not in query_machete_supported_quant_types(
c.zero_points):
return False, f"Quant type ({c.weight_type}) not supported by "\
"Machete, supported types are: "\
f"{query_machete_supported_quant_types(c.zero_points)}"
if c.group_size not in MACHETE_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Machete, supported group sizes are: "\
f"{MACHETE_SUPPORTED_GROUP_SIZES}"
return check_machete_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1])
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module):
c = self.config
if c.has_g_idx:
assert self.w_gidx_name is not None
perm = torch.argsort(getattr(layer, self.w_gidx_name))\
.to(torch.int)
self.act_perm = lambda x: x[:, perm]
# use `ops.permute_cols` if possible
if c.act_type in [torch.float16, torch.bfloat16] \
and c.partition_weight_shape[0] % 8 == 0:
self.act_perm = partial(ops.permute_cols, perm=perm)
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
if c.has_g_idx:
x_unpacked = unpack_weights_into_int32(x.data,
c.weight_type,
packed_dim=0)
x_perm = x_unpacked[perm, :]
x.data = pack_weights_into_int32(x_perm,
c.weight_type,
packed_dim=0)
x.data = ops.machete_prepack_B(x.data.t().contiguous().t(),
self.config.weight_type)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = x.data.contiguous()
return x
# Repack weights and scales for Machete
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, _, _ = self._get_weight_params(layer)
x_2d = x.reshape(-1, x.shape[-1])
out_shape = x.shape[:-1] + (c.partition_weight_shape[1], )
if c.has_g_idx:
x_2d = self.act_perm(x_2d)
output = ops.machete_gemm(a=x_2d,
b_q=w_q,
b_type=c.weight_type,
b_zeros=None,
b_scales=w_s,
b_group_size=c.group_size)
if bias is not None:
output.add_(bias) # In-place add
return output.reshape(out_shape)
from typing import Optional, Tuple
import torch
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.utils.marlin_utils import (
MARLIN_SUPPORTED_GROUP_SIZES, apply_gptq_marlin_linear,
check_marlin_supports_shape, marlin_is_k_full, marlin_make_empty_g_idx,
marlin_make_workspace, marlin_permute_scales, marlin_sort_g_idx,
query_marlin_supported_quant_types)
from vllm.model_executor.parameter import (BasevLLMParameter,
permute_param_layout_)
from .MPLinearKernel import MPLinearKernel, MPLinearLayerConfig
class MarlinLinearKernel(MPLinearKernel):
@classmethod
def get_min_capability(cls) -> int:
return 80
@classmethod
def can_implement(cls,
c: MPLinearLayerConfig) -> Tuple[bool, Optional[str]]:
if c.zero_points:
return False, "Zero points currently not supported by "\
" MarlinLinearKernel. Will be added when AWQMarlin "\
"is migrated over to using MPLinearKernel backend"
quant_types = query_marlin_supported_quant_types(c.zero_points)
if c.weight_type not in quant_types:
return False, f"Quant type ({c.weight_type}) not supported by"\
f" Marlin, supported types are: {quant_types}"
if c.group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
return False, f"Group size ({c.group_size}) not supported by "\
"Marlin, supported group sizes are: "\
f"{MARLIN_SUPPORTED_GROUP_SIZES}"
return check_marlin_supports_shape(c.partition_weight_shape[0],
c.partition_weight_shape[1],
c.full_weight_shape[1],
c.group_size)
# note assumes that
# `weight_packed` is: {input_dim = 0, output_dim = 1, packed_dim = 0}
# `weight_scale` is: {input_dim = 0, output_dim = 1}
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
device = getattr(layer, self.w_q_name).device
c = self.config
row_parallel = (c.partition_weight_shape[0] != c.full_weight_shape[0])
self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
# Allocate marlin workspace.
self.workspace = marlin_make_workspace(c.partition_weight_shape[1],
device)
# Default names since marlin requires empty parameters for these,
# TODO: remove this requirement from marlin (allow optional tensors)
if self.w_gidx_name is None:
self.w_gidx_name = "g_idx"
if self.w_zp_name is None:
self.w_zp_name = "w_zp"
if c.has_g_idx:
g_idx, g_idx_sort_indices = marlin_sort_g_idx(
getattr(layer, self.w_gidx_name))
self._transform_param(layer, self.w_gidx_name, lambda _: g_idx)
layer.g_idx_sort_indices = g_idx_sort_indices
else:
setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
if c.zero_points:
pass
# TODO (lucas): add the following when AWQMarlin is migrated over to
# using MPLinearKernel backend
# self._transform_param(layer, self.w_zp_name, lambda x: \
# marlin_zero_points(
# x,
# size_k=c.partition_weight_shape[0],
# size_n=c.partition_weight_shape[1],
# num_bits=c.weight_type.size_bits))
else:
setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
def transform_w_q(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
x.data = ops.gptq_marlin_repack(x.data.contiguous(),
perm=layer.g_idx_sort_indices,
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
num_bits=c.weight_type.size_bits)
return x
def transform_w_s(x):
assert isinstance(x, BasevLLMParameter)
permute_param_layout_(x, input_dim=0, output_dim=1)
x.data = marlin_permute_scales(x.data.contiguous(),
size_k=c.partition_weight_shape[0],
size_n=c.partition_weight_shape[1],
group_size=c.group_size)
return x
self._transform_param(layer, self.w_q_name, transform_w_q)
self._transform_param(layer, self.w_s_name, transform_w_s)
def apply_weights(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
c = self.config
w_q, w_s, w_zp, w_gidx = self._get_weight_params(layer)
# `process_weights_after_loading` will ensure w_zp and w_gidx are not
# None for marlin
return apply_gptq_marlin_linear(
input=x,
weight=w_q,
weight_scale=w_s,
weight_zp=w_zp, # type: ignore
g_idx=w_gidx, # type: ignore
g_idx_sort_indices=layer.g_idx_sort_indices,
workspace=self.workspace,
wtype=c.weight_type,
input_size_per_partition=c.partition_weight_shape[0],
output_size_per_partition=c.partition_weight_shape[1],
is_k_full=self.is_k_full,
bias=bias)
......@@ -260,7 +260,7 @@ class QQQLinearMethod(LinearMethodBase):
size_k = x_2d.shape[1]
size_n = s_ch.shape[1]
x_int8, s_tok = ops.scaled_int8_quant(x_2d)
x_int8, s_tok, _ = ops.scaled_int8_quant(x_2d)
output_2d = ops.marlin_qqq_gemm(x_int8, qweight, s_tok, s_ch, s_group,
workspace, size_m, size_n, size_k)
......
from .layer_utils import replace_parameter, update_tensor_inplace
__all__ = ['update_tensor_inplace', 'replace_parameter']
from typing import Union
import torch
def update_tensor_inplace(dst: torch.Tensor, src: torch.Tensor):
assert dst.dtype == src.dtype, "Tensors must have the same dtype"
# update tensor shape and stride
dst.as_strided_(src.shape, src.stride())
# If not the same underlying storage move tensor data
if dst.data_ptr() != src.data_ptr():
dst.copy_(src)
del src
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_parameter(mod: torch.nn.Module, name: str,
new: Union[torch.Tensor, torch.nn.Parameter]) -> None:
old = getattr(mod, name)
if type(old) is type(new) and old.dtype == new.dtype and \
old.untyped_storage().nbytes() == new.untyped_storage().nbytes():
# If we can just update in-place to avoid re-registering
# can be faster if the underlying storage is the same
update_tensor_inplace(old, new)
else:
# Fallback re-register parameter, convert to Parameter if necessary
# this not only ensures we don't register a tensor as a parameter, but
# also ensures that all parameter subclasses get re-registered as
# parameters for `torch.compile` compatibility
if not isinstance(new, torch.nn.Parameter):
new = torch.nn.Parameter(new, requires_grad=False)
mod.register_parameter(name,
torch.nn.Parameter(new, requires_grad=False))
from typing import List, Optional, Tuple
import torch
from vllm.scalar_type import ScalarType, scalar_types
MACHETE_SUPPORTED_GROUP_SIZES = [-1, 128]
MACHETE_PREPACKED_BLOCK_SHAPE = [64, 128]
def query_machete_supported_quant_types(zero_points: bool) -> List[ScalarType]:
if zero_points:
return [scalar_types.uint4, scalar_types.uint8]
else:
return [scalar_types.uint4b8, scalar_types.uint8b128]
def query_machete_supported_act_types(zero_points: bool) -> List[ScalarType]:
return [torch.float16, torch.bfloat16]
def check_machete_supports_shape(in_features: int, out_featrues: int) \
-> Tuple[bool, Optional[str]]:
if in_features % MACHETE_PREPACKED_BLOCK_SHAPE[0] != 0:
return False, "Input features size must be divisible by "\
f"{MACHETE_PREPACKED_BLOCK_SHAPE[0]}"
if out_featrues % MACHETE_PREPACKED_BLOCK_SHAPE[1] != 0:
return False, "Output features size must be divisible by "\
f"{MACHETE_PREPACKED_BLOCK_SHAPE[1]}"
return True, None
......@@ -29,8 +29,9 @@ def query_marlin_supported_quant_types(has_zp: bool,
device_capability: Optional[int] = None
):
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
if device_capability < 80:
return []
......@@ -52,8 +53,9 @@ def _check_marlin_supported(
device_capability: Optional[int] = None) -> Tuple[bool, Optional[str]]:
if device_capability is None:
major, minor = current_platform.get_device_capability()
device_capability = major * 10 + minor
capability_tuple = current_platform.get_device_capability()
device_capability = (-1 if capability_tuple is None else
capability_tuple.to_int())
supported_types = query_marlin_supported_quant_types(
has_zp, device_capability)
......@@ -118,6 +120,19 @@ def verify_marlin_supports_shape(output_size_per_partition: int,
"with --quantization gptq.")
def check_marlin_supports_shape(output_size_per_partition: int,
input_size_per_partition: int,
input_size: int, group_size: int) \
-> Tuple[bool, Optional[str]]:
try:
verify_marlin_supports_shape(output_size_per_partition,
input_size_per_partition, input_size,
group_size)
except ValueError as e:
return False, e.__str__()
return True, None
def marlin_make_workspace(output_size_per_partition: int,
device: torch.device) -> torch.Tensor:
max_workspace_size = (output_size_per_partition //
......@@ -146,6 +161,11 @@ def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
requires_grad=False)
def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
return torch.nn.Parameter(torch.empty(0, dtype=torch.int, device=device),
requires_grad=False)
def marlin_sort_g_idx(
g_idx: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
......@@ -238,17 +258,6 @@ def awq_to_marlin_zero_points(q_zp_packed: torch.Tensor, size_k: int,
return marlin_zp
# Newly generated tensors need to replace existing tensors that are
# already registered as parameters by vLLM (and won't be freed)
def replace_tensor(layer: torch.nn.Module, name: str,
new_t: torch.Tensor) -> None:
# It is important to use resize_() here since it ensures
# the same buffer is reused
getattr(layer, name).resize_(new_t.shape)
getattr(layer, name).copy_(new_t)
del new_t
def apply_gptq_marlin_linear(
input: torch.Tensor,
weight: torch.Tensor,
......
......@@ -10,8 +10,7 @@ from .marlin_utils import marlin_make_workspace, marlin_permute_scales
def is_fp8_marlin_supported():
capability = current_platform.get_device_capability()
return capability[0] >= 8
return current_platform.has_device_capability(80)
def apply_fp8_marlin_linear(
......
......@@ -20,6 +20,49 @@ FUSED_LAYER_NAME_MAPPING = {
}
def pack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
w_q_perm = w_q.permute(perm)
pack_factor = 32 // wtype.size_bits
mask = (1 << wtype.size_bits) - 1
new_shape_perm = list(w_q_perm.shape)
assert w_q_perm.shape[-1] % pack_factor == 0
new_shape_perm[-1] //= pack_factor
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
for i in range(pack_factor):
res |= (w_q_perm[..., i::pack_factor] & mask) << wtype.size_bits * i
return res.permute(inv_perm)
def unpack_weights_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
# move dim to pack to the end
perm = (*[i for i in range(len(w_q.shape)) if i != packed_dim], packed_dim)
inv_perm = tuple(perm.index(i) for i in range(len(perm)))
w_q_perm = w_q.permute(perm)
pack_factor = 32 // wtype.size_bits
mask = (1 << wtype.size_bits) - 1
new_shape_perm = list(w_q_perm.shape)
new_shape_perm[-1] *= pack_factor
res = torch.zeros(new_shape_perm, dtype=torch.int32, device=w_q.device)
for i in range(pack_factor):
res[..., i::pack_factor] = (w_q_perm >> wtype.size_bits * i) & mask
return res.permute(inv_perm)
def is_layer_skipped(prefix: str, ignored_layers: List[str]) -> bool:
# prefix: model.layers.0.self_attn.q_proj
# proj_name: q_proj
......
......@@ -6,19 +6,18 @@ from vllm import _custom_ops as ops
from vllm.platforms import current_platform
from vllm.utils import is_hip
# scaled_mm in pytorch on rocm has a bug that requires always
# providing scaling factor for result. This value is created
# as global value to avoid multiple tensor allocations, and
# can be removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT = torch.ones(1).cuda() if is_hip() else None
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY = torch.ones(1).cuda() if is_hip() else None
def cutlass_fp8_supported() -> bool:
# cutlass is not supported on Rocm
if is_hip():
return False
capability = current_platform.get_device_capability()
capability = capability[0] * 10 + capability[1]
capability_tuple = current_platform.get_device_capability()
capability = -1 if capability_tuple is None else capability_tuple.to_int()
return ops.cutlass_scaled_mm_supports_fp8(capability)
......@@ -130,19 +129,17 @@ def apply_fp8_linear(
if per_tensor_weights and per_tensor_activations:
# Fused GEMM_DQ
output = torch._scaled_mm(
qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
scale_result=TORCH_SCALED_MM_SCALE_RESULT,
bias=bias)
# Since in torch 2.5, scaled_mm only returns single value
# This should be removed when vllm-nvidia also moves to 2.5
if is_hip():
return torch.narrow(output, 0, 0, input.shape[0])
return torch.narrow(output[0], 0, 0, input.shape[0])
output = torch._scaled_mm(qinput,
weight,
out_dtype=input.dtype,
scale_a=x_scale,
scale_b=weight_scale,
bias=bias)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
return torch.narrow(output[0], 0, 0, input.shape[0])
return torch.narrow(output, 0, 0, input.shape[0])
else:
# Fallback for channelwise case, where we use unfused DQ
......@@ -160,12 +157,23 @@ def apply_fp8_linear(
# For the scaled_mm fallback case, we break this down, since it
# does not support s_w being a vector.
# Making sure the dummy tensor is on the same device as the weight
global TORCH_DEVICE_IDENTITY
if TORCH_DEVICE_IDENTITY.device != weight.device:
TORCH_DEVICE_IDENTITY = TORCH_DEVICE_IDENTITY.to(weight.device)
# GEMM
# This computes C = (X * W).
# Output in fp32 to allow subsequent ops to happen in-place
output, _ = torch._scaled_mm(qinput,
weight,
out_dtype=torch.float32)
output = torch._scaled_mm(qinput,
weight,
scale_a=TORCH_DEVICE_IDENTITY,
scale_b=TORCH_DEVICE_IDENTITY,
out_dtype=torch.float32)
# A fix for discrepancy in scaled_mm which returns tuple
# for torch < 2.5 and a single value in torch >= 2.5
if type(output) is tuple and len(output) == 2:
output = output[0]
# Unpad (undo num_token_padding)
output = torch.narrow(output, 0, 0, input.shape[0])
x_scale = torch.narrow(x_scale, 0, 0, input.shape[0])
......@@ -188,7 +196,7 @@ def apply_int8_linear(
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
x_q, x_scale = ops.scaled_int8_quant(input, input_scale)
x_q, x_scale, _ = ops.scaled_int8_quant(input, input_scale)
return ops.cutlass_scaled_mm(x_q,
weight,
......
......@@ -31,15 +31,11 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
"""
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False,
use_flashinfer: Optional[bool] = None):
"""Create a rejection sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
......@@ -48,8 +44,7 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
None, we will use the default value from the environment variable.
This parameter is only used for testing purposes.
"""
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
super().__init__(strict_mode=strict_mode)
if use_flashinfer is None:
self.use_flashinfer = envs.VLLM_USE_FLASHINFER_SAMPLER and (
chain_speculative_sampling is not None)
......@@ -57,8 +52,6 @@ class RejectionSampler(SpecDecodeStochasticBaseSampler):
self.use_flashinfer = use_flashinfer
if self.use_flashinfer:
assert not disable_bonus_tokens, \
"flashinfer will enable bonus token by default"
logger.info("Use flashinfer for rejection sampling.")
else:
logger.info("Use pytorch for rejection sampling.")
......
......@@ -10,19 +10,15 @@ import msgspec
import torch
import torch.nn as nn
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
from vllm.triton_utils import HAS_TRITON
if HAS_TRITON:
from vllm.model_executor.layers.ops.sample import sample as sample_triton
import vllm.envs as envs
from vllm.model_executor.sampling_metadata import (SamplingMetadata,
SamplingTensors,
SequenceGroupToSample)
from vllm.sampling_params import SamplingType
from vllm.sequence import (CompletionSequenceGroupOutput, Logprob,
from vllm.sequence import (VLLM_INVALID_TOKEN_ID,
CompletionSequenceGroupOutput, Logprob,
PromptLogprobs, SampleLogprobs, SequenceOutput)
from vllm.spec_decode.metrics import SpecDecodeWorkerMetrics
if envs.VLLM_USE_FLASHINFER_SAMPLER and find_spec("flashinfer"):
import flashinfer.sampling
......@@ -438,12 +434,9 @@ def _apply_top_k_top_p(
logits_sort.masked_fill_(top_p_mask, -float("inf"))
# Re-sort the probabilities.
src = torch.arange(logits_idx.shape[-1],
device=logits_idx.device).expand_as(logits_idx)
logits_idx_inv = torch.empty_like(logits_idx).scatter_(dim=-1,
index=logits_idx,
src=src)
logits = torch.gather(logits_sort, dim=-1, index=logits_idx_inv)
logits = torch.empty_like(logits_sort).scatter_(dim=-1,
index=logits_idx,
src=logits_sort)
return logits
......@@ -740,7 +733,7 @@ def _sample_with_torch(
) -> SampleReturnType:
'''Torch-oriented _sample() implementation.
Single-step scheduling:
Single-step scheduling:
* Perform GPU-side sampling computation
* Immediately Pythonize sampling result
......@@ -767,17 +760,17 @@ def _sample_with_torch(
# Create output tensor for sampled token ids.
if include_gpu_probs_tensor:
sampled_token_ids_tensor = torch.empty(logprobs.shape[0],
1,
dtype=torch.long,
device=logprobs.device)
sampled_token_ids_tensor = torch.full((logprobs.shape[0], 1),
VLLM_INVALID_TOKEN_ID,
dtype=torch.long,
device=logprobs.device)
else:
sampled_token_ids_tensor = None
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type][:, 0]
sample_indices = categorized_sample_indices[sampling_type]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
......@@ -863,88 +856,6 @@ def _sample_with_torch(
)
def _sample_with_triton_kernel(
probs: torch.Tensor,
logprobs: torch.Tensor,
sampling_metadata: SamplingMetadata,
sampling_tensors: SamplingTensors,
) -> SampleResultType:
categorized_seq_group_ids: Dict[SamplingType,
List[int]] = {t: []
for t in SamplingType}
categorized_sample_indices = sampling_metadata.categorized_sample_indices
for i, seq_group in enumerate(sampling_metadata.seq_groups):
sampling_params = seq_group.sampling_params
sampling_type = sampling_params.sampling_type
categorized_seq_group_ids[sampling_type].append(i)
sample_results_dict: Dict[int, Tuple[List[int], List[int]]] = {}
sample_metadata: Dict[SamplingType,
Tuple[List[int], List[SequenceGroupToSample],
torch.Tensor, torch.Tensor]] = {}
max_best_of_in_batch = 1
# Counterintiutively, having two loops here is actually faster.
# The first loop can run without waiting on GPU<->CPU sync.
for sampling_type in SamplingType:
sample_indices = categorized_sample_indices[sampling_type][:, 0]
sampled_token_indices = categorized_sample_indices[sampling_type][:, 1]
num_tokens = len(sample_indices)
if num_tokens == 0:
continue
seq_group_id = categorized_seq_group_ids[sampling_type]
seq_groups = [sampling_metadata.seq_groups[i] for i in seq_group_id]
sample_metadata[sampling_type] = (seq_group_id, seq_groups,
sample_indices,
sampled_token_indices)
if sampling_type in (SamplingType.GREEDY, SamplingType.RANDOM,
SamplingType.RANDOM_SEED):
for seq_group in seq_groups:
if seq_group.is_prompt:
sampling_params = seq_group.sampling_params
max_best_of_in_batch = max(max_best_of_in_batch,
sampling_params.best_of)
elif sampling_type == SamplingType.BEAM:
beam_search_logprobs = logprobs[sample_indices]
else:
raise ValueError(f"Unsupported sampling type: {sampling_type}")
sampled_tokens, _, _ = sample_triton(
probs=probs,
seeds=sampling_tensors.sampling_seeds,
max_best_of=max_best_of_in_batch,
sample_indices=sampling_tensors.sample_indices,
logprobs=logprobs,
# don't save logprobs because we have logic for that below
# TODO: use this instead of the CPU-based logic below
save_logprobs=False,
)
# GPU<->CPU sync happens in the loop below.
for sampling_type in SamplingType:
if sampling_type not in sample_metadata:
continue
(seq_group_id, seq_groups, sample_indices,
sampled_token_indices) = sample_metadata[sampling_type]
if sampling_type == SamplingType.GREEDY:
sample_results = _greedy_sample(
seq_groups, sampled_tokens[sampled_token_indices][:, 0])
elif sampling_type in (SamplingType.RANDOM, SamplingType.RANDOM_SEED):
sample_results = _random_sample(
seq_groups, sampled_tokens[sampled_token_indices])
elif sampling_type == SamplingType.BEAM:
sample_results = _beam_search_sample(seq_groups,
beam_search_logprobs)
sample_results_dict.update(zip(seq_group_id, sample_results))
sample_results = [
sample_results_dict.get(i, ([], []))
for i in range(len(sampling_metadata.seq_groups))
]
return sample_results
def _sample(
probs: torch.Tensor,
logprobs: torch.Tensor,
......@@ -974,10 +885,6 @@ def _sample(
modify_greedy_probs=modify_greedy_probs,
)
# TODO: Enable once Triton kernel & associated code is faster.
# return _sample_with_triton_kernel(probs, logprobs, sampling_metadata,
# sampling_tensors)
def _get_ranks(x: torch.Tensor, indices: torch.Tensor) -> torch.Tensor:
"""
......
......@@ -11,20 +11,14 @@ class SpecDecodeBaseSampler(nn.Module):
step.
"""
def __init__(self,
disable_bonus_tokens: bool = True,
strict_mode: bool = False):
def __init__(self, strict_mode: bool = False):
"""Base class constructor.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
"""
super().__init__()
self._disable_bonus_tokens = disable_bonus_tokens
self._strict_mode = strict_mode
# NOTE: A "bonus token" is accepted iff all proposal tokens are
......@@ -111,13 +105,6 @@ class SpecDecodeBaseSampler(nn.Module):
output_with_bonus_tokens[:, -1] = torch.where(output[:, -1] != -1,
bonus_token_ids, -1)
# We disable bonus tokens because it causes corrupt KV cache for
# proposal methods that require KV cache. We can fix it by "prefilling"
# the bonus token in the proposer. The following issue tracks the fix.
# https://github.com/vllm-project/vllm/issues/4212
if self._disable_bonus_tokens:
output_with_bonus_tokens[:, -1] = -1
# Fill the recovered token ids.
output.mul_(~after_false_mask).add_(
substitute_token_ids.mul(after_false_mask))
......
......@@ -16,15 +16,11 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
self,
posterior_threshold: float,
posterior_alpha: float,
disable_bonus_tokens: bool = False,
strict_mode: bool = False,
):
"""Create a Typical Acceptance Sampler.
Args:
disable_bonus_tokens: Whether or not to disable the bonus token.
Require when bonus tokens will cause corrupt KV cache for
proposal methods that require KV cache.
strict_mode: Whether or not to perform shape/device/dtype checks
during sampling. This catches correctness issues but adds
nontrivial latency.
......@@ -36,8 +32,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
"""
self._posterior_threshold = posterior_threshold
self._posterior_alpha = posterior_alpha
super().__init__(disable_bonus_tokens=disable_bonus_tokens,
strict_mode=strict_mode)
super().__init__(strict_mode=strict_mode)
def forward(
self,
......@@ -54,7 +49,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
one token will be emitted.
In the case where all draft tokens are accepted, the bonus token will be
accepted conditioned on self._disable_bonus_tokens being false.
accepted.
Args:
target_probs: The probability distribution over token ids given
......@@ -85,7 +80,7 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
target_probs = target_with_bonus_probs[:, :-1]
accepted = self._evaluate_accepted_tokens(target_probs,
draft_token_ids)
recovered_token_ids = self._replacement_token_ids(target_probs)
recovered_token_ids = self._get_recovered_token_ids(target_probs)
output_token_ids = self._create_output(accepted, recovered_token_ids,
draft_token_ids,
bonus_token_ids)
......@@ -153,16 +148,10 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
accepted_mask = candidates_prob > threshold
return accepted_mask
def _replacement_token_ids(self, target_probs):
def _get_recovered_token_ids(self, target_probs):
"""
Generate one replacement token ID for each sequence based on target
probabilities. The replacement token is used as the fallback option
if typical acceptance sampling does not accept any draft tokens for
that particular sequence.
This method computes the token IDs to be replaced by selecting the
token with the highest probability for each sequence in the first
position. The rest of the output is filled with -1.
The recovered token ids will fill the first unmatched token
by the target token.
Parameters
----------
......@@ -173,13 +162,9 @@ class TypicalAcceptanceSampler(SpecDecodeDeterministicBaseSampler):
Returns
-------
torch.Tensor
A tensor of shape (batch_size, k) with the replacement
token IDs. Only the first column is set, and the rest of the
columns are filled with -1.
A tensor of shape (batch_size, k) with the recovered token
ids which are selected from target probs.
"""
max_indices = torch.argmax(target_probs[:, 0, :], dim=1)
output = -torch.ones((target_probs.shape[0], target_probs.shape[1]),
dtype=self.token_id_dtype,
device=target_probs.device)
output[:, 0] = max_indices
return output
max_indices = torch.argmax(target_probs, dim=-1)
return max_indices
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment