Commit 4a943d35 authored by 王敏
Browse files

[feat]1.w8a8 marlin适配deepep低延迟;2.非naive ep模式,去掉多余的dp padding,避免allreduce耗时

parent b956fc64
......@@ -4,7 +4,6 @@
import enum
from enum import Enum
from typing import Callable, Optional
from math import prod
import torch
from compressed_tensors import CompressionFormat
......@@ -14,8 +13,6 @@ from compressed_tensors.quantization import (ActivationOrdering,
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEConfig, FusedMoEMethodBase,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,
......@@ -35,16 +32,10 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_emulation_utils import
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
all_close_1d, normalize_e4m3fn_to_e4m3fnuz, per_tensor_dequantize)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import W8a8GetCacheJSON
try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
logger = init_logger(__name__)
......@@ -1009,27 +1000,11 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"dynamic per token quantization. Found static input scales.")
self.tritonsingleton= W8a8GetCacheJSON()
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.int8
# WEIGHTS
......@@ -1127,6 +1102,8 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
"EPLB not supported for "
"`CompressedTensorsW8A8Int8MoEMethod` yet.")
from vllm.model_executor.layers.fused_moe import fused_experts
topk_weights, topk_ids = FusedMoE.select_experts(
hidden_states=x,
router_logits=router_logits,
......@@ -1139,53 +1116,12 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None,)
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor,
)
e_score_correction_bias=e_score_correction_bias)
def fused_moe_forward(self,
x: torch.Tensor,
w1: torch.Tensor,
w2: torch.Tensor,
topk_ids: torch.Tensor,
topk_weights: torch.Tensor,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
apply_router_weight_on_input: bool = False,
activation: str = "silu",
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
expert_num_tokens: Optional[torch.Tensor] = None,
use_nn_moe: Optional[bool] = False,
routed_scaling_factor: Optional[float] = None,
shared_output: Optional[torch.Tensor] = None,
**_ ):
from vllm.model_executor.layers.fused_moe import fused_experts
return fused_experts(
hidden_states=x,
w1=w1,
w2=w2,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
......@@ -1195,149 +1131,14 @@ class CompressedTensorsW8A8Int8MoEMethod(CompressedTensorsMoEMethod):
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
def groupgemm_workspace_shapes(self,
                               a: torch.Tensor,
                               aq: torch.Tensor,
                               M: int,
                               N: int,
                               K: int,
                               topk: int,
                               global_num_experts: int,
                               local_num_experts: int,):
    """Return the buffer shapes and dtype used by the grouped-GEMM path.

    Args:
        a: 2-D activation tensor; its dtype is returned as the workspace
            dtype and its first dimension is the fallback token count.
        aq: Unused here; kept for signature compatibility.
        M: Unused here; kept for signature compatibility.
        N: Width of the first GEMM output (the third shape uses ``N // 2``).
        K: Width of the final output.
        topk: Unused here; kept for signature compatibility.
        global_num_experts: Unused here; kept for signature compatibility.
        local_num_experts: Leading dimension of every returned shape.

    Returns:
        ``(workspace13, workspace2, output, dtype)`` where the first three
        entries are 3-tuples of ints and ``dtype`` is ``a.dtype``.
    """
    assert a.dim() == 2
    # FIXME (varun): We should be able to dispatch only from the leader
    # DP ranks in the case of TP > 1. At the moment, all the Ranks
    # end up sending their tokens. This needs to be fixed.
    num_dispatchers = self.num_dispatchers

    # ``max_num_tokens_per_rank`` is only assigned by ``select_gemm_impl``
    # on its batched branch, so it may be missing entirely; treat a missing
    # attribute like ``None`` instead of raising AttributeError.
    tokens_per_rank = getattr(self, "max_num_tokens_per_rank", None)
    if tokens_per_rank is None:
        tokens_per_rank = a.size(0)

    rows = tokens_per_rank * num_dispatchers
    workspace13 = (local_num_experts, rows, max(K, N))
    workspace2 = (local_num_experts, rows, (N // 2))
    output = (local_num_experts, rows, K)
    return (workspace13, workspace2, output, a.dtype)
def w8a8_groupgemm_forward(self,
                           x: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
                           topk_ids: torch.Tensor,
                           topk_weights: torch.Tensor,
                           global_num_experts: int = -1,
                           expert_map: Optional[torch.Tensor] = None,
                           apply_router_weight_on_input: bool = False,
                           activation: str = "silu",
                           w1_scale: Optional[torch.Tensor] = None,
                           w2_scale: Optional[torch.Tensor] = None,
                           a1_scale: Optional[torch.Tensor] = None,
                           a2_scale: Optional[torch.Tensor] = None,
                           expert_num_tokens: Optional[torch.Tensor] = None,
                           use_nn_moe: Optional[bool] = False,
                           routed_scaling_factor: Optional[float] = None,
                           shared_output: Optional[torch.Tensor] = None,
                           q_x: Optional[torch.Tensor] = None,
                           **_ ):
    """Run the two masked grouped W8A8 GEMMs with a fused SiLU-mul-quant between.

    The sequence is: ``m_grouped_w8a8_gemm_nt_masked`` on ``(q_x, a1_scale)``
    and ``(w1, w1_scale)``, then ``fuse_silu_mul_quant_ep`` on that result,
    then a second ``m_grouped_w8a8_gemm_nt_masked`` against
    ``(w2, w2_scale)``.

    Args:
        x: Supplies the device and dtype for the workspace allocation and
            is passed to ``groupgemm_workspace_shapes``.
        w1, w2: Expert weights for the first and second GEMM.
        topk_ids: Forwarded to ``mk._moe_problem_size``.
        w1_scale, w2_scale: Scales paired with ``w1`` / ``w2``.
        a1_scale: Scale paired with ``q_x``.
        expert_num_tokens: Passed to both GEMMs and to the fused
            activation; must not be ``None``.
        q_x: Quantized activations fed to the first GEMM; must not be
            ``None``.
        topk_weights, expert_map, apply_router_weight_on_input, activation,
        a2_scale, use_nn_moe, routed_scaling_factor, shared_output:
            Accepted for signature compatibility; not read in this method.

    Returns:
        The output buffer written by the second GEMM.
    """
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    # Validate before first use: both values feed the very first kernel
    # call, so checking afterwards would be too late to give a clear error.
    assert q_x is not None
    assert expert_num_tokens is not None

    local_num_experts = w1.size(0)
    E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
        q_x, w1, w2, topk_ids)
    N, K = self.N, self.K

    (workspace13_shape, _, fused_out_shape,
     workspace_dtype) = self.groupgemm_workspace_shapes(
         x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
         local_num_experts)

    workspace13 = torch.empty(prod(workspace13_shape),
                              device=x.device,
                              dtype=workspace_dtype)

    # NOTE(review): both tensors below are derived from the same
    # ``workspace13`` allocation via ``_resize_cache``; presumably they
    # share storage, which is only safe because ``workspace1`` is consumed
    # by the fused activation before the second GEMM writes ``fused_out``.
    workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
    fused_out = _resize_cache(workspace13, fused_out_shape)

    # (from deepgemm docs) : A value hint (which is a value on CPU)
    # for the M expectation of each batch, correctly setting this value
    # may lead to better performance.
    expected_m = max_num_tokens // 2

    m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
                                  (w1, w1_scale),
                                  workspace1,
                                  expert_num_tokens,
                                  expected_m,
                                  )

    a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)

    m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
                                  (w2, w2_scale),
                                  fused_out,
                                  expert_num_tokens,
                                  expected_m)
    return fused_out
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
    """Build the ``TritonOrGroupGemmExperts`` wrapper for this method.

    When ``prepare_finalize.activation_format`` is ``BatchedExperts`` the
    wrapper is bound to ``w8a8_groupgemm_forward`` and the per-rank token
    limit reported by ``prepare_finalize`` is stored on ``self``.
    Otherwise the wrapper is bound to ``fused_moe_forward``.

    Args:
        prepare_finalize: Source of the activation format and, on the
            batched path, of ``max_num_tokens_per_rank()``.
        moe: Not read in this method.

    Returns:
        A ``TritonOrGroupGemmExperts`` instance.
    """
    from vllm.model_executor.layers.fused_moe import (
        TritonOrGroupGemmExperts)

    batched = (prepare_finalize.activation_format ==
               FusedMoEActivationFormat.BatchedExperts)

    if not batched:
        # Standard (non-batched) activations: use the plain fused-MoE path.
        logger.debug(
            "TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__, None,
            False)
        return TritonOrGroupGemmExperts(
            fused_experts=self.fused_moe_forward
        )

    # Batched activations: remember the per-rank limit, which
    # ``groupgemm_workspace_shapes`` reads later.
    tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
    assert tokens_per_rank is not None
    self.max_num_tokens_per_rank = tokens_per_rank
    logger.debug(
        "TritonOrGroupGemmExperts(%s): "
        "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
        self.__class__.__name__, tokens_per_rank,
        None, True)
    return TritonOrGroupGemmExperts(
        use_int8_w8a8=True,
        per_act_token_quant=True,
        fused_experts=self.w8a8_groupgemm_forward
    )
class CompressedTensorsWNA16MarlinMoEMethod(CompressedTensorsMoEMethod):
......
......@@ -4,17 +4,27 @@
import enum
from enum import Enum
from typing import Callable, Optional
from math import prod
import torch
from compressed_tensors.quantization import (QuantizationStrategy)
import vllm.envs as envs
from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group, get_dp_group
from torch.nn.parameter import Parameter
from vllm.model_executor.layers.fused_moe import (
FusedMoE, FusedMoEActivationFormat, FusedMoEMethodBase,
FusedMoeWeightScaleSupported)
FusedMoEConfig, FusedMoeWeightScaleSupported,
FusedMoEPermuteExpertsUnpermute, FusedMoEPrepareAndFinalize,)
from vllm.model_executor.utils import set_weight_attrs
from vllm.model_executor.layers.fused_moe.utils import _resize_cache
from vllm.model_executor.layers.quantization.utils.w8a8_utils import(
get_w8a8_int8_marlin_weights)
get_w8a8_int8_marlin_weights, w8a8_nt_kpack2_marlin_weight)
try:
from lightop import m_grouped_w8a8_gemm_nt_masked, fuse_silu_mul_quant_ep
from lmslim.layers.fused_moe.fuse_moe_int8_marlin import fused_experts_impl_int8_marlin
except Exception:
print("INFO: Please install lmslim if you want to infer the quantitative model of moe.\n")
......@@ -69,11 +79,28 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
"For INT8 Fused MoE layers, we require channelwise, "
"dynamic per token quantization. Found static input scales.")
self.fused_experts = self.fused_moe_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
dp_size = get_dp_group().world_size
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
if self.use_deepep:
all2all_manager = get_ep_group().device_communicator.all2all_manager
assert all2all_manager is not None
self.num_dispatchers = all2all_manager.world_size
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
params_dtype: torch.dtype, **extra_weight_attrs):
if self.use_deepep:
self.N = 2 * intermediate_size_per_partition
self.K = hidden_size
params_dtype = torch.int8
# WEIGHTS
......@@ -124,20 +151,152 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    """Repack every expert's ``w13`` / ``w2`` weight and store it on ``layer``.

    Each expert slice is passed through ``w8a8_nt_kpack2_marlin_weight``
    when ``self.use_deepep`` is set and through
    ``get_w8a8_int8_marlin_weights`` otherwise. The per-expert results are
    stacked along dim 0 and written back as non-trainable ``Parameter``s.

    Args:
        layer: Module holding ``w13_weight`` and ``w2_weight``; both are
            replaced in place.
    """
    # The choice depends only on the dispatch mode, so resolve it once
    # rather than re-testing it for every expert.
    if self.use_deepep:
        repack = w8a8_nt_kpack2_marlin_weight
    else:
        repack = get_w8a8_int8_marlin_weights

    # The temporary per-expert lists live only for the duration of each
    # ``torch.stack`` call, so the w13 list is released before w2 is built.
    w1_marlin = torch.stack(
        [repack(layer.w13_weight[ii])
         for ii in range(layer.w13_weight.shape[0])],
        dim=0)
    w2_marlin = torch.stack(
        [repack(layer.w2_weight[ii])
         for ii in range(layer.w2_weight.shape[0])],
        dim=0)

    layer.w13_weight = Parameter(w1_marlin, requires_grad=False)
    layer.w2_weight = Parameter(w2_marlin, requires_grad=False)
def groupgemm_workspace_shapes(self,
                               a: torch.Tensor,
                               aq: torch.Tensor,
                               M: int,
                               N: int,
                               K: int,
                               topk: int,
                               global_num_experts: int,
                               local_num_experts: int,):
    """Return the buffer shapes and dtype used by the grouped-GEMM path.

    Args:
        a: 2-D activation tensor; its dtype is returned as the workspace
            dtype and its first dimension is the fallback token count.
        aq: Unused here; kept for signature compatibility.
        M: Unused here; kept for signature compatibility.
        N: Width of the first GEMM output (the third shape uses ``N // 2``).
        K: Width of the final output.
        topk: Unused here; kept for signature compatibility.
        global_num_experts: Unused here; kept for signature compatibility.
        local_num_experts: Leading dimension of every returned shape.

    Returns:
        ``(workspace13, workspace2, output, dtype)`` where the first three
        entries are 3-tuples of ints and ``dtype`` is ``a.dtype``.
    """
    assert a.dim() == 2
    # FIXME (varun): We should be able to dispatch only from the leader
    # DP ranks in the case of TP > 1. At the moment, all the Ranks
    # end up sending their tokens. This needs to be fixed.
    num_dispatchers = self.num_dispatchers

    # ``max_num_tokens_per_rank`` is only assigned by ``select_gemm_impl``
    # on its batched branch, so it may be missing entirely; treat a missing
    # attribute like ``None`` instead of raising AttributeError.
    tokens_per_rank = getattr(self, "max_num_tokens_per_rank", None)
    if tokens_per_rank is None:
        tokens_per_rank = a.size(0)

    rows = tokens_per_rank * num_dispatchers
    workspace13 = (local_num_experts, rows, max(K, N))
    workspace2 = (local_num_experts, rows, (N // 2))
    output = (local_num_experts, rows, K)
    return (workspace13, workspace2, output, a.dtype)
def w8a8_groupgemm_forward(self,
                           x: torch.Tensor,
                           w1: torch.Tensor,
                           w2: torch.Tensor,
                           topk_ids: torch.Tensor,
                           topk_weights: torch.Tensor,
                           global_num_experts: int = -1,
                           expert_map: Optional[torch.Tensor] = None,
                           apply_router_weight_on_input: bool = False,
                           activation: str = "silu",
                           w1_scale: Optional[torch.Tensor] = None,
                           w2_scale: Optional[torch.Tensor] = None,
                           a1_scale: Optional[torch.Tensor] = None,
                           a2_scale: Optional[torch.Tensor] = None,
                           expert_num_tokens: Optional[torch.Tensor] = None,
                           use_nn_moe: Optional[bool] = False,
                           routed_scaling_factor: Optional[float] = None,
                           shared_output: Optional[torch.Tensor] = None,
                           q_x: Optional[torch.Tensor] = None,
                           **_ ):
    """Run the two masked grouped W8A8 GEMMs with a fused SiLU-mul-quant between.

    The sequence is: ``m_grouped_w8a8_gemm_nt_masked`` on ``(q_x, a1_scale)``
    and ``(w1, w1_scale)``, then ``fuse_silu_mul_quant_ep`` on that result,
    then a second ``m_grouped_w8a8_gemm_nt_masked`` against
    ``(w2, w2_scale)``.

    Args:
        x: Supplies the device and dtype for the workspace allocation and
            is passed to ``groupgemm_workspace_shapes``.
        w1, w2: Expert weights for the first and second GEMM.
        topk_ids: Forwarded to ``mk._moe_problem_size``.
        w1_scale, w2_scale: Scales paired with ``w1`` / ``w2``.
        a1_scale: Scale paired with ``q_x``.
        expert_num_tokens: Passed to both GEMMs and to the fused
            activation; must not be ``None``.
        q_x: Quantized activations fed to the first GEMM; must not be
            ``None``.
        topk_weights, expert_map, apply_router_weight_on_input, activation,
        a2_scale, use_nn_moe, routed_scaling_factor, shared_output:
            Accepted for signature compatibility; not read in this method.

    Returns:
        The output buffer written by the second GEMM.
    """
    import vllm.model_executor.layers.fused_moe.modular_kernel as mk

    # Validate before first use: both values feed the very first kernel
    # call, so checking afterwards would be too late to give a clear error.
    assert q_x is not None
    assert expert_num_tokens is not None

    local_num_experts = w1.size(0)
    E, max_num_tokens, _, _, top_k = mk._moe_problem_size(
        q_x, w1, w2, topk_ids)
    N, K = self.N, self.K

    (workspace13_shape, _, fused_out_shape,
     workspace_dtype) = self.groupgemm_workspace_shapes(
         x, q_x, max_num_tokens, N, K, top_k, global_num_experts,
         local_num_experts)

    workspace13 = torch.empty(prod(workspace13_shape),
                              device=x.device,
                              dtype=workspace_dtype)

    # NOTE(review): both tensors below are derived from the same
    # ``workspace13`` allocation via ``_resize_cache``; presumably they
    # share storage, which is only safe because ``workspace1`` is consumed
    # by the fused activation before the second GEMM writes ``fused_out``.
    workspace1 = _resize_cache(workspace13, (E, max_num_tokens, N))
    fused_out = _resize_cache(workspace13, fused_out_shape)

    # (from deepgemm docs) : A value hint (which is a value on CPU)
    # for the M expectation of each batch, correctly setting this value
    # may lead to better performance.
    expected_m = max_num_tokens

    m_grouped_w8a8_gemm_nt_masked((q_x, a1_scale),
                                  (w1, w1_scale),
                                  workspace1,
                                  expert_num_tokens,
                                  expected_m,
                                  )

    a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)

    m_grouped_w8a8_gemm_nt_masked((a2q, a2q_scale),
                                  (w2, w2_scale),
                                  fused_out,
                                  expert_num_tokens,
                                  expected_m)
    return fused_out
def fused_moe_forward(self,
                      x: torch.Tensor,
                      w1: torch.Tensor,
                      w2: torch.Tensor,
                      topk_ids: torch.Tensor,
                      topk_weights: torch.Tensor,
                      global_num_experts: int = -1,
                      expert_map: Optional[torch.Tensor] = None,
                      apply_router_weight_on_input: bool = False,
                      activation: str = "silu",
                      w1_scale: Optional[torch.Tensor] = None,
                      w2_scale: Optional[torch.Tensor] = None,
                      a1_scale: Optional[torch.Tensor] = None,
                      a2_scale: Optional[torch.Tensor] = None,
                      expert_num_tokens: Optional[torch.Tensor] = None,
                      use_nn_moe: Optional[bool] = False,
                      routed_scaling_factor: Optional[float] = None,
                      shared_output: Optional[torch.Tensor] = None,
                      **_ ):
    """Delegate to ``fused_experts_impl_int8_marlin`` with int8 W8A8 settings.

    The call always passes ``inplace=True``, ``use_int8_w8a8=True``,
    ``per_channel_quant=True`` and ``use_nn_moe=False``; the ``use_nn_moe``
    and ``expert_num_tokens`` arguments of this method are accepted but not
    forwarded. All other named arguments are passed through unchanged.

    Returns:
        Whatever ``fused_experts_impl_int8_marlin`` returns.
    """
    # Options fixed by this method regardless of the caller's arguments.
    fixed_options = dict(
        inplace=True,
        use_int8_w8a8=True,
        per_channel_quant=True,
        use_nn_moe=False,
    )
    # Arguments forwarded verbatim from the caller.
    forwarded = dict(
        hidden_states=x,
        w1=w1,
        w2=w2,
        topk_weights=topk_weights,
        topk_ids=topk_ids,
        activation=activation,
        apply_router_weight_on_input=apply_router_weight_on_input,
        global_num_experts=global_num_experts,
        expert_map=expert_map,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
        shared_output=shared_output,
        routed_scaling_factor=routed_scaling_factor,
    )
    return fused_experts_impl_int8_marlin(**forwarded, **fixed_options)
def apply(
self,
......@@ -183,25 +342,59 @@ class CompressedTensorsW8A8Int8MarlinMoEMethod(CompressedTensorsMarlinMoEMethod)
scoring_func=scoring_func,
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate,
e_score_correction_bias=e_score_correction_bias)
e_score_correction_bias=e_score_correction_bias,
indices_type=torch.int64 if self.use_deepep else None,)
return fused_experts_impl_int8_marlin(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
return self.fused_experts(
x,
layer.w13_weight,
layer.w2_weight,
topk_weights=topk_weights,
topk_ids=topk_ids,
inplace=True,
activation=activation,
apply_router_weight_on_input=apply_router_weight_on_input,
use_int8_w8a8=True,
per_channel_quant=True,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=layer.w13_weight_scale,
w2_scale=layer.w2_weight_scale,
w1_scale=(layer.w13_weight_scale),
w2_scale=(layer.w2_weight_scale),
a1_scale=layer.w13_input_scale,
a2_scale=layer.w2_input_scale,
use_nn_moe=False,
use_nn_moe=use_nn_moe,
shared_output=shared_output,
routed_scaling_factor=routed_scaling_factor)
\ No newline at end of file
routed_scaling_factor=routed_scaling_factor,
)
def select_gemm_impl(
    self,
    prepare_finalize: FusedMoEPrepareAndFinalize,
    moe: FusedMoEConfig,
) -> FusedMoEPermuteExpertsUnpermute:
    """Build the ``TritonOrGroupGemmExperts`` wrapper for this method.

    When ``prepare_finalize.activation_format`` is ``BatchedExperts`` the
    wrapper is bound to ``w8a8_groupgemm_forward`` and the per-rank token
    limit reported by ``prepare_finalize`` is stored on ``self``.
    Otherwise the wrapper is bound to ``fused_moe_forward``.

    Args:
        prepare_finalize: Source of the activation format and, on the
            batched path, of ``max_num_tokens_per_rank()``.
        moe: Not read in this method.

    Returns:
        A ``TritonOrGroupGemmExperts`` instance.
    """
    from vllm.model_executor.layers.fused_moe import (
        TritonOrGroupGemmExperts)

    batched = (prepare_finalize.activation_format ==
               FusedMoEActivationFormat.BatchedExperts)

    if not batched:
        # Standard (non-batched) activations: use the plain fused-MoE path.
        logger.debug(
            "TritonOrGroupGemmExperts(%s): block_size=%s, per_act_token=%s",
            self.__class__.__name__, None,
            False)
        return TritonOrGroupGemmExperts(
            fused_experts=self.fused_moe_forward
        )

    # Batched activations: remember the per-rank limit, which
    # ``groupgemm_workspace_shapes`` reads later.
    tokens_per_rank = prepare_finalize.max_num_tokens_per_rank()
    assert tokens_per_rank is not None
    self.max_num_tokens_per_rank = tokens_per_rank
    logger.debug(
        "TritonOrGroupGemmExperts(%s): "
        "max_tokens_per_rank=%s, block_size=%s, per_act_token=%s",
        self.__class__.__name__, tokens_per_rank,
        None, True)
    return TritonOrGroupGemmExperts(
        use_int8_w8a8=True,
        per_act_token_quant=True,
        fused_experts=self.w8a8_groupgemm_forward
    )
\ No newline at end of file
......@@ -8,7 +8,7 @@ from torch.nn.parameter import Parameter
import vllm.envs as envs
from vllm import _custom_ops as ops
from vllm.model_executor.utils import set_weight_attrs
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group
from vllm.distributed import get_tensor_model_parallel_world_size, get_ep_group, get_dp_group
from vllm.logger import init_logger
from vllm.config import get_current_vllm_config
from vllm.forward_context import ForwardContext, get_forward_context
......@@ -163,7 +163,8 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
self.fused_experts = self.w4a8_fused_moe_marlin_forward
vllm_config = get_current_vllm_config()
parallel_config = vllm_config.parallel_config
self.use_deepep = parallel_config.enable_expert_parallel and \
dp_size = get_dp_group().world_size
self.use_deepep = dp_size > 1 and parallel_config.enable_expert_parallel and \
(envs.VLLM_ALL2ALL_BACKEND == "deepep_high_throughput" or \
envs.VLLM_ALL2ALL_BACKEND == "deepep_low_latency")
......@@ -352,16 +353,6 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
# for the M expectation of each batch, correctly setting this value
# may lead to better performance.
expected_m = max_num_tokens
# forward_context = get_forward_context()
# expected_m = forward_context.dp_metadata.max_tokens_across_dp_cpu * self.num_dispatchers
# print("##########################gemm1 workspace1 shape:{} q_x shape:{} " \
# "a1_scale shape:{} w1 shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# q_x.shape,
# a1_scale.shape,
# w1.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w4a8_gemm_nt_masked((q_x, a1_scale),
(w1, w1_scale),
......@@ -371,17 +362,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
)
assert expert_num_tokens is not None
a2q, a2q_scale = fuse_silu_mul_quant_ep(workspace1, expert_num_tokens)
# print("##########################gemm2 workspace1 shape:{} a2q shape:{} " \
# "a2q_scale shape:{} w2 shape:{} fused_out shape:{} expert_num_tokens:{} expected_m:{}".format(workspace1.shape,
# a2q.shape,
# a2q_scale.shape,
# w2.shape,
# fused_out.shape,
# expert_num_tokens,
# expected_m))
m_grouped_w4a8_gemm_nt_masked((a2q, a2q_scale),
(w2, w2_scale),
fused_out,
......@@ -477,6 +458,7 @@ class SlimQuantW4A8Int8MarlinMoEMethod:
routed_scaling_factor=routed_scaling_factor,
use_fused_gate=use_fused_gate
)
return self.fused_experts(
x,
layer.w13_weight,
......
......@@ -40,6 +40,19 @@ def get_w8a8_int8_marlin_weights(
return weight
def w8a8_nt_kpack2_marlin_weight(w8a8_w,  # 2-D int8 tensor, read as [size_n, size_k]
                                 k_tile=16,
                                 n_tile=16, ):
    """Regroup a 2-D int8 weight so each ``n_tile x k_tile`` tile is contiguous.

    The input is viewed as ``(size_n // n_tile, n_tile, size_k // k_tile,
    k_tile)``, the two middle axes are swapped, and the result is flattened
    to ``(size_n // n_tile, size_k * n_tile)``. Each output row therefore
    holds one band of ``n_tile`` input rows, laid out tile by tile.

    NOTE(review): the original inline comment described the input as
    ``[size_n, size_k // 2]``; the code itself treats the second dimension
    as ``size_k`` with no packing factor - confirm against the caller.

    Args:
        w8a8_w: 2-D ``torch.int8`` tensor.
        k_tile: Tile extent along the second (k) dimension.
        n_tile: Tile extent along the first (n) dimension.

    Returns:
        A contiguous int8 tensor of shape
        ``(size_n // n_tile, size_k * n_tile)``.

    Raises:
        AssertionError: If the dtype is not int8 or a dimension is not a
            multiple of its tile size.
    """
    assert w8a8_w.dtype == torch.int8, "w8a8_w 必须是 int8 类型"
    size_n, size_k = w8a8_w.shape
    # Each dimension must be checked against the tile size it is actually
    # divided by below (n with n_tile, k with k_tile).
    assert size_n % n_tile == 0 and size_k % k_tile == 0, "k_tile / n_tile 必须能整除对应维度"

    tiled = w8a8_w.reshape((size_n // n_tile, n_tile, size_k // k_tile, k_tile))
    tiled = tiled.permute((0, 2, 1, 3)).contiguous()
    # One output row per band of n_tile input rows; the band holds
    # (size_k // k_tile) tiles of n_tile * k_tile elements = size_k * n_tile.
    return tiled.reshape((size_n // n_tile, size_k * n_tile))
def sparse_cutlass_supported() -> bool:
if not current_platform.is_cuda():
return False
......
......@@ -1235,7 +1235,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# TODO(tms) : There are many cases where padding is enabled for
# prefills, causing unnecessary and excessive padding of activations.
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND == 'naive':
if dp_size == 1 or self.vllm_config.model_config.enforce_eager or envs.VLLM_ALL2ALL_BACKEND != 'naive':
# Early exit.
return 0, None
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment