Unverified Commit c1d2061f authored by Ying Sheng, committed by GitHub

Add initial support for gpt-oss (#8824)

parent 556e4143
......@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
......@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
)
window_num_kv_splits = torch.empty(
......@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
mask_indptr = None
max_extend_len = None
elif forward_batch.forward_mode.is_target_verify():
# TODO: Support sliding window in spec inference
bs = len(forward_batch.req_pool_indices)
qo_indptr = torch.arange(
0,
......@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
self.req_to_token.stride(0),
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indptr, window_kv_indices, window_kv_lens = (
update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.seq_lens,
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
)
custom_mask = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (
forward_batch.seq_lens + self.num_draft_tokens
......@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
qo_indptr = self.qo_indptr
......@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
):
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
window_kv_indptr, window_kv_indices, _ = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
......@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
self.req_to_token.stride(0),
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, window_kv_indices, _ = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
......@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
):
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
......@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens[:bs],
req_pool_indices[:bs],
bs,
self.token_to_kv_pool_allocator,
)
self.get_num_kv_splits(
window_num_kv_splits[:num_token], window_kv_lens[:bs]
......@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
kv_indices,
self.req_to_token.stride(0),
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
......@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
......@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
sliding_window_size,
sliding_window_size=sliding_window_size,
sk=sk,
)
return o
......@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
......@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
self.max_kv_splits,
layer.scaling,
layer.logit_cap,
sk=sk,
)
return o
......@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
req_pool_indices,
bs,
device,
token_to_kv_pool_allocator=None,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
torch.tensor(sliding_window_size),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
......@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
window_kv_indices,
req_to_token.stride(0),
)
# full to swa index mapping
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
kv_last_index = window_kv_indptr[-1]
window_kv_indices[:kv_last_index] = (
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens
......@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
seq_lens,
req_pool_indices,
bs,
token_to_kv_pool_allocator=None,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
torch.tensor(sliding_window_size),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
......@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_lens
# full to swa index mapping
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
kv_last_index = window_kv_indptr[-1]
window_kv_indices[:kv_last_index] = (
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens
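Editor's note: the two helpers above now optionally remap the gathered window indices from the full KV pool into the dedicated sliding-window (SWA) pool. The allocator's actual mapping is not shown in this diff; the sketch below is only a stand-in that illustrates where translate_loc_from_full_to_swa plugs in (ToyHybridAllocator and its lookup table are hypothetical).

import torch

class ToyHybridAllocator:
    """Hypothetical allocator: keeps a lookup from full-pool slots to SWA-pool slots."""

    def __init__(self, full_pool_size: int):
        # -1 marks full-pool slots with no SWA counterpart.
        self.full_to_swa = torch.full((full_pool_size,), -1, dtype=torch.int64)

    def translate_loc_from_full_to_swa(self, loc: torch.Tensor) -> torch.Tensor:
        return self.full_to_swa[loc]

allocator = ToyHybridAllocator(full_pool_size=1024)
allocator.full_to_swa[:256] = torch.arange(256)       # pretend 256 slots are mirrored

# window_kv_indices as produced by the index-gathering kernel above
window_kv_indices = torch.tensor([3, 7, 42, 0, 0], dtype=torch.int64)
kv_last_index = 3                                      # i.e. window_kv_indptr[-1]
window_kv_indices[:kv_last_index] = allocator.translate_loc_from_full_to_swa(
    window_kv_indices[:kv_last_index]
)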
......@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
O,
kv_indptr,
num_kv_splits,
sk_ptr,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
......@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
MIN_BLOCK_KV: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
HAS_SK: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
e_sum += tl.exp(cur_sk - e_max)
tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
......@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk=None,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
MAX_KV_SPLITS = max_kv_splits
HAS_SK = sk is not None
extra_kargs = {}
if _is_hip:
......@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
o,
kv_indptr,
num_kv_splits,
sk,
logits.stride(0),
logits.stride(1),
logits.stride(2),
......@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
MIN_BLOCK_KV=_MIN_BLOCK_KV,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
HAS_SK=HAS_SK,
num_warps=4,
num_stages=2,
**extra_kargs,
......@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
_decode_att_m_fwd(
q,
......@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
)
......@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
_decode_grouped_att_m_fwd(
q,
......@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
)
......@@ -687,6 +701,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
assert max_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
......@@ -709,6 +724,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
)
else:
# GQA/MQA/MLA
......@@ -725,4 +741,5 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
)
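Editor's note: the new sk argument threads a per-head "attention sink" logit down to the softmax-reduce step; _fwd_kernel_stage2 above only adds exp(sk - e_max) to the denominator, so the sink dampens the real attention weights without contributing a value row. A plain PyTorch sketch of the resulting math (not the Triton kernel itself):

import torch

def decode_attn_with_sink(q, k, v, sk, sm_scale):
    # q: [H, D], k/v: [S, H, D], sk: [H]; a single decode token for clarity.
    logits = torch.einsum("hd,shd->hs", q, k) * sm_scale     # [H, S]
    m = logits.max(dim=-1).values                            # per-head max
    e = torch.exp(logits - m[:, None])
    # The sink joins the denominator but has no value, so rows sum to < 1.
    denom = e.sum(dim=-1) + torch.exp(sk - m)
    probs = e / denom[:, None]
    return torch.einsum("hs,shd->hd", probs, v)

H, S, D = 4, 16, 64
out = decode_attn_with_sink(torch.randn(H, D), torch.randn(S, H, D),
                            torch.randn(S, H, D), torch.zeros(H), D ** -0.5)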
......@@ -51,6 +51,7 @@ def _fwd_kernel(
kv_indices,
mask_ptr,
mask_indptr,
sk_ptr,
sm_scale,
kv_group_num,
stride_qbs,
......@@ -78,6 +79,7 @@ def _fwd_kernel(
IS_CAUSAL: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
HAS_SK: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -178,13 +180,17 @@ def _fwd_kernel(
final_mask &= custom_mask
if SLIDING_WINDOW_SIZE > 0:
# Add mask where q_id <= kv_id + sliding_window_size
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
)
# q_id = prefix_len + cur_m, kv_id = cur_n
window_mask = (
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
final_mask &= window_mask
qk = tl.where(final_mask, qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
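Editor's note: the clamp of the per-row max above guards the sliding-window path, where an entire qk tile can be masked to -inf and exp(-inf - (-inf)) is NaN. A quick PyTorch check of the failure mode and the fix:

import torch

qk = torch.full((4, 8), float("-inf"))                  # fully masked tile
row_max = qk.max(dim=1).values                          # all -inf
row_max_fixed = torch.where(row_max == float("-inf"),
                            torch.full_like(row_max, -1e20), row_max)

p_bad = torch.exp(qk - row_max[:, None])                # NaN: -inf minus -inf
p_ok = torch.exp(qk - row_max_fixed[:, None])           # 0 everywhere
assert torch.isnan(p_bad).all() and (p_ok == 0).all()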
......@@ -242,6 +248,7 @@ def _fwd_kernel(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK:
custom_mask = tl.load(
mask_ptr
......@@ -254,18 +261,30 @@ def _fwd_kernel(
other=0,
)
custom_mask &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(custom_mask, qk, float("-inf"))
final_mask &= custom_mask
elif IS_CAUSAL:
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causual &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causual, qk, float("-inf"))
final_mask &= mask_causual
else:
mask_non_causal = mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_non_causal, qk, float("-inf"))
final_mask &= mask_non_causal
if SLIDING_WINDOW_SIZE > 0:
# Add mask where q_id <= kv_id + sliding_window_size
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
)
final_mask &= window_mask
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
......@@ -283,6 +302,10 @@ def _fwd_kernel(
e_max = n_e_max
if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
deno += tl.exp(cur_sk - e_max)
offs_o = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
......@@ -321,6 +344,7 @@ def extend_attention_fwd(
logit_cap=0.0,
skip_prefix_custom_mask=True,
sliding_window_size=-1,
sk=None,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
......@@ -386,6 +410,8 @@ def extend_attention_fwd(
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
HAS_SK = sk is not None
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1
......@@ -405,6 +431,7 @@ def extend_attention_fwd(
kv_indices,
custom_mask,
mask_indptr,
sk,
sm_scale,
kv_group_num,
q_extend.stride(0),
......@@ -431,6 +458,7 @@ def extend_attention_fwd(
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
IS_CAUSAL=is_causal,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
HAS_SK=HAS_SK,
STORE_TRANSPOSE=_is_hip,
num_warps=num_warps,
num_stages=num_stages,
......
......@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
else self.weight_loader
),
)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
......
......@@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
use_weight_loader_fused: bool = False,
with_bias=False,
):
super().__init__()
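Editor's note: for reference, a hypothetical construction of a gpt-oss style expert layer using the new constructor options. Every value below is illustrative (the real sizes and alpha/limit come from the model config), and the pre-existing FusedMoE arguments are assumed from its usual signature rather than shown in this hunk.

experts = FusedMoE(
    num_experts=32,
    top_k=4,
    hidden_size=2880,
    intermediate_size=2880,
    with_bias=True,                # gpt-oss MoE projections carry biases
    activation_alpha=1.702,        # steepness of the clamped SwiGLU gate
    swiglu_limit=7.0,              # clamp applied inside the fused activation
    use_weight_loader_fused=True,  # checkpoint stores gate_up_proj as one fused tensor
)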
......@@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = None
self.expert_map_gpu = None
# For activation
self.activation_alpha = activation_alpha
self.swiglu_limit = swiglu_limit
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_cutlass_moe = False
......@@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module):
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
self.use_triton_kernels, with_bias=with_bias
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
......@@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module):
intermediate_size=self.intermediate_size_per_partition,
intermediate_size_per_partition=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader
if not use_weight_loader_fused
else self.weight_loader_fused
),
with_bias=with_bias,
)
def _load_per_tensor_weight_scale(
......@@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
# Load grouped weight scales for group quantization
# or model weights
......@@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module):
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
elif shard_id in ("w1", "w3"):
elif shard_id in ("w1", "w3", "w13"):
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
def _load_per_channel_weight_scale(
......@@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
assert shard_id in {"w1", "w3", "w13"}
if is_bias:
# if this weight is a bias, the last dimension must be the sharded dimension
shard_dim = -1
if shard_id in {"w1", "w3"}:
# non-fused version
shard_size = expert_data.shape[shard_dim] // 2
elif shard_id in {"w13"}:
# fused version
shard_size = expert_data.shape[shard_dim]
else:
raise NotImplementedError
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
# w3, up_proj: Load into second logical weight of w13.
# trtllm cutlass kernel assumes differently
assert shard_id in ("w1", "w3")
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
start = shard_size
......@@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module):
)
else:
if not self.use_presharded_weights:
if self.use_triton_kernels:
if not is_bias and self.use_triton_kernels:
# do not transpose for bias
loaded_weight = loaded_weight.transpose(-2, -1)
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
......@@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
"""Load w2 weights for down projection.
......@@ -356,7 +387,14 @@ class FusedMoE(torch.nn.Module):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if is_bias:
# expert_data is a bias, not a weight; w2_bias is not sharded under TP,
# so take the full last dimension
shard_size = expert_data.shape[-1]
else:
# expert_data is a weight matrix; for w2 under TP the input-feature
# dimension is sharded (shard_dim, i.e. 2 in the non-transposed layout)
shard_size = expert_data.shape[shard_dim]
if _is_cpu:
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
......@@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module):
not self.use_presharded_weights,
)
else:
if not self.use_presharded_weights:
if not is_bias and not self.use_presharded_weights:
if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
......@@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module):
)
return
def weight_loader_fused(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
) -> None:
tp_rank = self.moe_tp_rank
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO: check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = (
loaded_weight.t().contiguous()
if (
self.quant_method.__class__.__name__
== "CompressedTensorsWNA16MoEMethod"
)
else loaded_weight
)
if shard_id not in ("w13", "w2"):
raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
expert_data = param.data
is_bias = expert_data.dim() == 2
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
if self.use_triton_kernels:
is_transposed = True
shard_dim = (
SHARD_ID_TO_SHARDED_DIM[shard_id]
if not is_transposed
else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
)
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
return
else:
logging.warning(
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
)
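Editor's note: a short summary of the shapes weight_loader_fused expects, assuming the triton-kernels (transposed) layout. The bias shapes are taken from create_weights in this diff; the weight shapes are inferred from the shard-dim tables and should be read as an assumption.

# E = local experts, H = hidden_size, I = intermediate_size per partition.
#   w13_weight      : (E, H, 2 * I)  -> "w13", sharded on dim 2 (transposed table)
#   w2_weight       : (E, I, H)      -> "w2",  sharded on dim 1 (transposed table)
#   w13_weight_bias : (E, 2 * I)     -> 2-D, so is_bias=True; sharded on its last dim
#   w2_weight_bias  : (E, H)         -> 2-D; not sharded under TP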
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
assert self.quant_method is not None
......@@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module):
# Matrix multiply.
with use_symmetric_memory(get_tp_group()) as sm:
kwargs = {}
if self.activation_alpha is not None:
kwargs["activation_alpha"] = self.activation_alpha
if self.swiglu_limit is not None:
kwargs["swiglu_limit"] = self.swiglu_limit
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
......@@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module):
== "ModelOptNvFp4FusedMoEMethod"
else {}
),
**kwargs,
)
sm.tag(final_hidden_states)
......@@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module):
]
]
@classmethod
def make_expert_params_mapping_fused(
cls,
ckpt_gate_up_proj_name: str,
ckpt_down_proj_name: str,
ckpt_gate_up_proj_bias_name: str,
ckpt_down_proj_bias_name: str,
):
return [
("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
(
"experts.w13_weight_bias",
f"experts.{ckpt_gate_up_proj_bias_name}",
"w13",
),
("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
]
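Editor's note: make_expert_params_mapping_fused just pairs the fused parameter names with their checkpoint keys. A hypothetical call, with illustrative ckpt_* names (not taken from the actual gpt-oss checkpoint) and assuming the FusedMoE class above is importable:

mapping = FusedMoE.make_expert_params_mapping_fused(
    ckpt_gate_up_proj_name="gate_up_proj",
    ckpt_down_proj_name="down_proj",
    ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
    ckpt_down_proj_bias_name="down_proj_bias",
)
# mapping[0] == ("experts.w13_weight", "experts.gate_up_proj", "w13")
# Each (param_name, ckpt_name, shard_id) triple is later dispatched to
# weight_loader_fused, which accepts only the fused shard ids "w13" and "w2".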
@classmethod
def make_expert_input_scale_params_mapping(
cls,
......
......@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
import torch
from sgl_kernel import gelu_and_mul, silu_and_mul
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.matmul_ogs import (
FlexCtx,
FnSpecs,
FusedActivation,
PrecisionConfig,
matmul_ogs,
)
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from sglang.srt.utils import direct_register_custom_op
from triton_kernels.swiglu import swiglu_fn
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
def quantize(w, dtype, dev, **opt):
if dtype == "bf16":
return w.to(torch.bfloat16), InFlexData()
elif dtype == "fp8":
wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
return (
wq,
InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
MicroscalingCtx(),
)
else:
assert dtype == "mx4", f"{dtype=}"
swizzle_mx_scale = opt["swizzle_mx_scale"]
swizzle_axis = 2 if swizzle_mx_scale else None
w = w.to(torch.bfloat16)
w, mx_scales, weight_scale_shape = downcast_to_mxfp(
w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
)
return (
w,
InFlexData(),
MicroscalingCtx(
weight_scale=mx_scales,
swizzle_mx=swizzle_mx_scale,
actual_weight_scale_shape=weight_scale_shape,
),
)
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
)
return intermediate_cache3
def triton_kernel_moe_with_bias_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
b1: torch.Tensor,
w2: torch.Tensor,
b2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
assert topk_output.format.is_triton_kernel()
routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts_with_bias(
hidden_states,
w1,
b1,
w2,
b2,
routing_data,
gather_idx,
scatter_idx,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
def triton_kernel_fused_experts_with_bias(
hidden_states: torch.Tensor,
w1: torch.Tensor,
b1: torch.Tensor,
w2: torch.Tensor,
b2: torch.Tensor,
routing_data: RoutingData,
gather_indx: GatherIndx,
scatter_indx: ScatterIndx,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported"
assert w1_scale == None, "w1_scale is not supported"
assert w2_scale == None, "w2_scale is not supported"
assert a1_scale == None, "a1_scale is not supported"
assert a2_scale == None, "a2_scale is not supported"
assert block_shape == None, "block_shape is not supported"
# type check
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
# Shape check
assert hidden_states.ndim == 2, "hidden_states must be 2D"
assert (
hidden_states.shape[-1] == w1.shape[-2]
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
assert (
w2.shape[-1] == w1.shape[1]
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
# feature check
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
E, _, _ = w1.shape
if global_num_experts == -1:
global_num_experts = E
device = "cuda"
optg = dict()
w1, w1_flex = quantize(w1, "bf16", device, **optg)
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
w2, w2_flex = quantize(w2, "bf16", device, **optg)
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
(activation_alpha, swiglu_limit),
2,
)
intermediate_cache = matmul_ogs(
hidden_states,
w1,
b1,
routing_data,
gather_indx=gather_indx,
precision_config=w1_pcg,
gammas=None,
fused_activation=act,
)
return matmul_ogs(
intermediate_cache,
w2,
b2,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_pcg,
gammas=routing_data.gate_scal,
)
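Editor's note: the FusedActivation above binds (activation_alpha, swiglu_limit) to triton_kernels' swiglu_fn. As a readable stand-in only (the exact element layout and formula are defined by the Triton kernel, not here), the sketch below assumes the gpt-oss style clamped SwiGLU that these two parameters suggest:

import torch

def clamped_swiglu_reference(x: torch.Tensor, alpha: float, limit: float) -> torch.Tensor:
    # Assumes the two halves are concatenated on the last dim; the fused kernel
    # may use a different (e.g. interleaved) layout.
    gate, up = x.chunk(2, dim=-1)
    gate = gate.clamp(max=limit)            # clamp only the gate's positive side
    up = up.clamp(min=-limit, max=limit)    # clamp the linear half on both sides
    return (up + 1) * gate * torch.sigmoid(alpha * gate)

y = clamped_swiglu_reference(torch.randn(8, 2 * 2880), alpha=1.702, limit=7.0)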
......@@ -4,6 +4,7 @@ import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.layers.utils import is_sm100_supported
try:
......@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.utils import (
align,
ceil_div,
get_bool_env_var,
get_cuda_version,
get_device_capability,
......@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
return output.to(dtype=input_2d.dtype).view(*output_shape)
def dequant_mxfp4(
w_block: torch.Tensor,
w_scale: torch.Tensor,
out_dtype,
) -> torch.Tensor:
"""
:param w_block: (batch, n, k, 16), uint8; each byte packs two mxfp4 values
:param w_scale: (batch, n, k), uint8 E8M0 scales, one per 32-value block
:param out_dtype: target dtype of the dequantized tensor
:return: (batch, n, k * 32), out_dtype
"""
assert w_block.dtype == torch.uint8
assert w_scale.dtype == torch.uint8
batch, n, k, pack_dim = w_block.shape
batch_, n_, k_ = w_scale.shape
assert pack_dim == 16
assert batch == batch_
assert n == n_
assert k == k_
out_raw = MXFP4QuantizeUtil.dequantize(
quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
)
return out_raw.reshape(batch, n, k * 32)
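Editor's note: a shape sanity check for dequant_mxfp4. Each uint8 packs two FP4 values and each E8M0 scale covers a 32-value block, so (batch, n, k, 16) blocks plus (batch, n, k) scales dequantize to (batch, n, k * 32). Minimal example, assuming dequant_mxfp4 defined above is in scope:

import torch

E, N, K = 4, 64, 256                         # K must be a multiple of 32
w_block = torch.randint(0, 256, (E, N, K // 32, 16), dtype=torch.uint8)
w_scale = torch.full((E, N, K // 32), 127, dtype=torch.uint8)   # 2**(127-127) = 1.0
w = dequant_mxfp4(w_block, w_scale, torch.bfloat16)
assert w.shape == (E, N, K) and w.dtype == torch.bfloat16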
def input_to_float8(
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
......
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
class MXFP4QuantizeUtil:
E2M1_max = 6.0
E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
@classmethod
def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
"""Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
Args:
input (torch.Tensor): The input tensor to be quantized.
block_sizes (dict | None): The block sizes for quantization.
"""
def cast_fp4(x):
sign = torch.sign(x)
sign_bit = (2 - sign) // 2
ord_ = torch.sum(
(x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
)
fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
return fp4_val
def fuse_uint4_to_uint8(x):
# If the last dimension is odd, pad with zeros
# If this behavior is not desired, please modify the code accordingly
left_side = x[..., 0::2] # Even indices (0, 2, 4...)
right_side = x[..., 1::2] # Odd indices (1, 3, 5...)
new_data = (
right_side.clone() << 4
) # Put odd indices (higher addresses) in high bits
new_data[
..., : left_side.shape[-1]
] += left_side # Put even indices in low bits
return new_data
if block_size is None:
block_size = 32
original_shape = input.shape
original_dtype = input.dtype
input = input.view(-1, block_size)
# get scales
input_amax = input.abs().max(dim=-1, keepdim=True).values
descale = input_amax / cls.E2M1_max
min_value = torch.tensor(-127.0, device=descale.device)
e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
input = (input / torch.exp2(e8m0_scale)).view(original_shape)
input_q = cast_fp4(input)
input_q = fuse_uint4_to_uint8(input_q)
e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
return cls(original_shape, original_dtype, input_q), e8m0_scale
@classmethod
def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
"""Dequantze MXFP4 packed tensor to a target dtype."""
def unfuse_uint8_to_uint4(x):
"""Unfuse uint8 values back to uint4 values.
This is the inverse operation of fuse_uint4_to_uint8.
"""
# Extract the lower 4 bits (even indices)
left_side = x & 0x0F
# Extract the upper 4 bits (odd indices)
right_side = (x >> 4) & 0x0F
# Create a new tensor with alternating values
shape = list(x.shape)
shape[-1] = shape[-1] * 2
result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
# Fill in the values - even indices get low bits, odd indices get high bits
result[..., 0::2] = left_side # Even indices from low bits
result[..., 1::2] = right_side # Odd indices from high bits
return result
e8m0_scale = scale
block_size = block_sizes[-1]
# Unfuse the uint8 values back to uint4
x_unfused = unfuse_uint8_to_uint4(quantized_data)
# Extract sign and magnitude
sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
torch.float32
) # Extract sign bit and convert to +1/-1
magnitude = x_unfused & 0b0111 # Extract magnitude bits
magnitude = magnitude.to(torch.long)
# Create a tensor with the E2M1 values
values = torch.tensor(cls.E2M1_values, device=quantized_data.device)
# Use gather to index the values tensor properly
# We need to reshape magnitude to match the dimensions we want to gather along
original_shape = magnitude.shape
x_float = values[magnitude.reshape(-1)].reshape(original_shape)
# Apply sign and scale
x_float = sign.float() * x_float
# Reshape to apply block-wise scaling
x_float = x_float.reshape(-1, block_size)
# Apply the E8M0 scale
scale_factor = torch.exp2(e8m0_scale.float() - 127)
scale_factor = scale_factor.reshape(-1, 1) # Reshape for proper broadcasting
# Apply scaling and reshape back to original shape
x_float = x_float * scale_factor
# Reshape back to the original shape
return x_float.reshape(original_shape).to(dtype)
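Editor's note: a worked example of the scheme implemented by MXFP4QuantizeUtil, done by hand so it does not depend on the class plumbing. One block with amax 12 gets an E8M0 scale of 2**1 (stored as 1 + 127 = 128), and the value 3.2 lands in the E2M1 bucket 1.5, dequantizing back to 3.0:

import math

E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
E2M1_bounds = [0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5]

amax = 12.0
descale = amax / 6.0                                   # E2M1_max = 6.0
e8m0 = math.ceil(math.log2(descale))                   # 1; stored byte = 1 + 127 = 128

x = 3.2
scaled = x / 2.0 ** e8m0                               # 1.6
ordinal = sum(b < abs(scaled) for b in E2M1_bounds)    # 3, same rule as cast_fp4
assert E2M1_values[ordinal] == 1.5
dequant = E2M1_values[ordinal] * 2.0 ** e8m0           # 3.0 vs the original 3.2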
......@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, use_triton_kernels: bool = False):
def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
super().__init__()
self.use_triton_kernels = use_triton_kernels
self.with_bias = with_bias
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
if torch.cuda.is_available() and has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward as _tk_forward,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
)
self.triton_kernel_moe_forward = _tk_forward
self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
def create_weights(
self,
......@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.with_bias:
w13_weight_bias = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_bias", w13_weight_bias)
set_weight_attrs(w13_weight_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight_n, w2_weight_k = (
hidden_size,
......@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.with_bias:
w2_weight_bias = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_bias", w2_weight_bias)
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
......@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
kwargs = {}
if activation_alpha is not None:
kwargs["activation_alpha"] = activation_alpha
if swiglu_limit is not None:
kwargs["swiglu_limit"] = swiglu_limit
return self.forward(
x=x,
......@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
**kwargs,
)
def forward_cuda(
......@@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
if self.use_triton_kernels:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
)
if self.with_bias:
return self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
else:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
)
else:
if _use_aiter:
assert not no_combine, "unsupported"
......
......@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
is_hybrid = False
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
assert isinstance(tree_cache, SWARadixCache) or isinstance(
tree_cache, SWAChunkCache
assert (
tree_cache is None
or isinstance(tree_cache, SWARadixCache)
or isinstance(tree_cache, SWAChunkCache)
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
is_hybrid = True
......
[One file's diff is collapsed and not shown here.]
......@@ -457,6 +457,10 @@ class ServerArgs:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
)
model_arch = self.get_hf_config().architectures[0]
if model_arch in ["GptOssForCausalLM"]:
self.attention_backend = "triton"
self.enable_triton_kernel_moe = True
# Set page size
if self.page_size is None:
......