Unverified Commit c1d2061f authored by Ying Sheng, committed by GitHub

Add initial support for gpt-oss (#8824)

parent 556e4143
......@@ -88,6 +88,7 @@ class TritonAttnBackend(AttentionBackend):
self.window_kv_indptr = torch.zeros_like(kv_indptr_buf)
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.token_to_kv_pool_allocator = model_runner.token_to_kv_pool_allocator
if not self.skip_prefill:
self.qo_indptr = torch.zeros(
......@@ -197,6 +198,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
)
window_num_kv_splits = torch.empty(
......@@ -225,7 +227,6 @@ class TritonAttnBackend(AttentionBackend):
mask_indptr = None
max_extend_len = None
elif forward_batch.forward_mode.is_target_verify():
# TODO: Support sliding window in spec inference
bs = len(forward_batch.req_pool_indices)
qo_indptr = torch.arange(
0,
......@@ -250,6 +251,20 @@ class TritonAttnBackend(AttentionBackend):
self.req_to_token.stride(0),
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indptr, window_kv_indices, window_kv_lens = (
update_sliding_window_buffer(
self.window_kv_indptr,
self.req_to_token,
self.sliding_window_size,
forward_batch.seq_lens,
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
)
custom_mask = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (
forward_batch.seq_lens + self.num_draft_tokens
......@@ -308,6 +323,7 @@ class TritonAttnBackend(AttentionBackend):
forward_batch.req_pool_indices,
bs,
self.device,
self.token_to_kv_pool_allocator,
)
qo_indptr = self.qo_indptr
......@@ -423,14 +439,17 @@ class TritonAttnBackend(AttentionBackend):
):
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, _ = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
window_kv_indptr, window_kv_indices, _ = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens[:bs],
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
......@@ -464,6 +483,22 @@ class TritonAttnBackend(AttentionBackend):
self.req_to_token.stride(0),
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_kv_indices = self.cuda_graph_window_kv_indices
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indptr, window_kv_indices, _ = (
update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
......@@ -557,7 +592,7 @@ class TritonAttnBackend(AttentionBackend):
):
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, window_kv_lens = update_sliding_window_buffer_cuda_graph(
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
......@@ -565,6 +600,7 @@ class TritonAttnBackend(AttentionBackend):
seq_lens[:bs],
req_pool_indices[:bs],
bs,
self.token_to_kv_pool_allocator,
)
self.get_num_kv_splits(
window_num_kv_splits[:num_token], window_kv_lens[:bs]
......@@ -599,6 +635,19 @@ class TritonAttnBackend(AttentionBackend):
kv_indices,
self.req_to_token.stride(0),
)
if self.sliding_window_size is not None and self.sliding_window_size > 0:
window_num_kv_splits = self.cuda_graph_window_num_kv_splits
window_kv_indices = self.cuda_graph_window_kv_indices
_, _, window_kv_lens = update_sliding_window_buffer_cuda_graph(
self.window_kv_indptr,
window_kv_indices,
self.req_to_token,
self.sliding_window_size,
seq_lens,
req_pool_indices,
bs,
self.token_to_kv_pool_allocator,
)
custom_mask = self.cuda_graph_custom_mask
custom_mask[: spec_info.custom_mask.shape[0]] = spec_info.custom_mask
seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
......@@ -637,6 +686,7 @@ class TritonAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
):
# TODO: reuse the buffer across layers
if layer.qk_head_dim != layer.v_head_dim:
......@@ -680,7 +730,8 @@ class TritonAttnBackend(AttentionBackend):
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
sliding_window_size,
sliding_window_size=sliding_window_size,
sk=sk,
)
return o
......@@ -692,6 +743,7 @@ class TritonAttnBackend(AttentionBackend):
layer: RadixAttention,
forward_batch: ForwardBatch,
save_kv_cache=True,
sk=None,
):
# During torch.compile, there is a bug in rotary_emb that causes the
# output value to have a 3D tensor shape. This reshapes the output correctly.
......@@ -728,6 +780,7 @@ class TritonAttnBackend(AttentionBackend):
self.max_kv_splits,
layer.scaling,
layer.logit_cap,
sk=sk,
)
return o
......@@ -932,10 +985,11 @@ def update_sliding_window_buffer(
req_pool_indices,
bs,
device,
token_to_kv_pool_allocator=None,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
torch.tensor(sliding_window_size),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
......@@ -952,6 +1006,14 @@ def update_sliding_window_buffer(
window_kv_indices,
req_to_token.stride(0),
)
# full to swa index mapping
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
kv_last_index = window_kv_indptr[-1]
window_kv_indices[:kv_last_index] = (
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens
......@@ -963,10 +1025,11 @@ def update_sliding_window_buffer_cuda_graph(
seq_lens,
req_pool_indices,
bs,
token_to_kv_pool_allocator=None,
):
window_kv_lens = torch.minimum(
seq_lens,
torch.tensor(sliding_window_size + 1),
torch.tensor(sliding_window_size),
)
window_kv_indptr[1 : bs + 1] = torch.cumsum(window_kv_lens, dim=0)
window_kv_indptr = window_kv_indptr[: bs + 1]
......@@ -980,4 +1043,12 @@ def update_sliding_window_buffer_cuda_graph(
window_kv_indices,
req_to_token.stride(0),
)
return window_kv_indptr, window_kv_lens
# full to swa index mapping
if hasattr(token_to_kv_pool_allocator, "translate_loc_from_full_to_swa"):
kv_last_index = window_kv_indptr[-1]
window_kv_indices[:kv_last_index] = (
token_to_kv_pool_allocator.translate_loc_from_full_to_swa(
window_kv_indices[:kv_last_index]
)
)
return window_kv_indptr, window_kv_indices, window_kv_lens
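An aside, not part of the commit: `update_sliding_window_buffer` (and its CUDA-graph variant) gathers only the last `min(seq_len, sliding_window_size)` KV slots per request and, when the allocator exposes `translate_loc_from_full_to_swa`, remaps those slots from the full KV pool into the separate SWA pool. Below is a minimal dense-tensor sketch of that indexing; `build_window_indices` and the `translate` lambda are illustrative stand-ins, not SGLang APIs.

```python
import torch

def build_window_indices(req_to_token, req_pool_indices, seq_lens, window, translate=None):
    """Toy re-implementation of the windowed KV gather done by the Triton helper."""
    window_kv_lens = torch.minimum(seq_lens, torch.tensor(window))
    window_kv_indptr = torch.zeros(len(seq_lens) + 1, dtype=torch.long)
    window_kv_indptr[1:] = torch.cumsum(window_kv_lens, dim=0)
    indices = []
    for req, seq_len, win_len in zip(
        req_pool_indices.tolist(), seq_lens.tolist(), window_kv_lens.tolist()
    ):
        # keep only the last `win_len` token slots of this request
        indices.append(req_to_token[req, seq_len - win_len : seq_len])
    window_kv_indices = torch.cat(indices)
    if translate is not None:
        # full-pool -> SWA-pool remapping, mirroring translate_loc_from_full_to_swa
        window_kv_indices = translate(window_kv_indices)
    return window_kv_indptr, window_kv_indices, window_kv_lens

# two requests of length 5 and 3, window of 4
req_to_token = torch.arange(20).reshape(2, 10)
indptr, indices, lens = build_window_indices(
    req_to_token,
    req_pool_indices=torch.tensor([0, 1]),
    seq_lens=torch.tensor([5, 3]),
    window=4,
    translate=lambda loc: loc + 100,  # stand-in for the allocator's mapping
)
print(lens.tolist())     # [4, 3]
print(indices.tolist())  # [101, 102, 103, 104, 110, 111, 112]
```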
......@@ -495,6 +495,7 @@ def _fwd_kernel_stage2(
O,
kv_indptr,
num_kv_splits,
sk_ptr,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
......@@ -504,6 +505,7 @@ def _fwd_kernel_stage2(
MIN_BLOCK_KV: tl.constexpr,
BLOCK_DV: tl.constexpr,
Lv: tl.constexpr,
HAS_SK: tl.constexpr,
):
cur_batch = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -545,6 +547,10 @@ def _fwd_kernel_stage2(
e_sum = e_sum * old_scale + exp_logic
e_max = n_e_max
if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
e_sum += tl.exp(cur_sk - e_max)
tl.store(
O + cur_batch * stride_obs + cur_head * stride_oh + offs_d,
acc / e_sum,
......@@ -561,12 +567,14 @@ def _decode_softmax_reducev_fwd(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk=None,
):
batch, head_num = q.shape[0], q.shape[1]
Lv = v_buffer.shape[-1]
BLOCK_DV = triton.next_power_of_2(Lv)
MAX_KV_SPLITS = max_kv_splits
HAS_SK = sk is not None
extra_kargs = {}
if _is_hip:
......@@ -581,6 +589,7 @@ def _decode_softmax_reducev_fwd(
o,
kv_indptr,
num_kv_splits,
sk,
logits.stride(0),
logits.stride(1),
logits.stride(2),
......@@ -590,6 +599,7 @@ def _decode_softmax_reducev_fwd(
MIN_BLOCK_KV=_MIN_BLOCK_KV,
BLOCK_DV=BLOCK_DV,
Lv=Lv,
HAS_SK=HAS_SK,
num_warps=4,
num_stages=2,
**extra_kargs,
......@@ -609,6 +619,7 @@ def decode_attention_fwd_normal(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
_decode_att_m_fwd(
q,
......@@ -632,6 +643,7 @@ def decode_attention_fwd_normal(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
)
......@@ -648,6 +660,7 @@ def decode_attention_fwd_grouped(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
_decode_grouped_att_m_fwd(
q,
......@@ -671,6 +684,7 @@ def decode_attention_fwd_grouped(
kv_indptr,
num_kv_splits,
max_kv_splits,
sk,
)
......@@ -687,6 +701,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=0.0,
sk=None,
):
assert max_kv_splits == attn_logits.shape[2]
assert q.shape[0] <= kv_indptr.shape[0] - 1
......@@ -709,6 +724,7 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
)
else:
# GQA/MQA/MLA
......@@ -725,4 +741,5 @@ def decode_attention_fwd(
max_kv_splits,
sm_scale,
logit_cap=logit_cap,
sk=sk,
)
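For reference (illustrative, not from this patch): the new `sk` argument threads a per-head "attention sink" logit into the softmax denominator (`e_sum += tl.exp(cur_sk - e_max)` in the stage-2 kernel) without contributing a value vector. A small eager-mode reference of that reduction for one head, ignoring KV splitting:

```python
import torch

def attention_with_sink(q, k, v, sink, scale):
    """Single-head reference: the sink logit enlarges the softmax denominator only."""
    logits = (q @ k.T) * scale                   # [num_q, num_kv]
    m = logits.max(dim=-1).values                # the kernel's running e_max
    p = torch.exp(logits - m[:, None])
    denom = p.sum(dim=-1) + torch.exp(sink - m)  # HAS_SK: e_sum += exp(sk - e_max)
    return (p @ v) / denom[:, None]              # no value contribution from the sink

torch.manual_seed(0)
q, k, v = torch.randn(2, 8), torch.randn(5, 8), torch.randn(5, 8)
out = attention_with_sink(q, k, v, sink=torch.tensor(0.5), scale=8 ** -0.5)
print(out.shape)  # torch.Size([2, 8])
```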
......@@ -51,6 +51,7 @@ def _fwd_kernel(
kv_indices,
mask_ptr,
mask_indptr,
sk_ptr,
sm_scale,
kv_group_num,
stride_qbs,
......@@ -78,6 +79,7 @@ def _fwd_kernel(
IS_CAUSAL: tl.constexpr,
SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
STORE_TRANSPOSE: tl.constexpr,
HAS_SK: tl.constexpr,
):
cur_seq = tl.program_id(0)
cur_head = tl.program_id(1)
......@@ -178,13 +180,17 @@ def _fwd_kernel(
final_mask &= custom_mask
if SLIDING_WINDOW_SIZE > 0:
# Add mask where q_id <= kv_id + sliding_window_size
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
)
# q_id = prefix_len + cur_m, kv_id = cur_n
window_mask = (
cur_seq_len_prefix + cur_block_m * BLOCK_M + offs_m[:, None]
) <= (start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE)
final_mask &= window_mask
qk = tl.where(final_mask, qk, float("-inf"))
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
......@@ -242,6 +248,7 @@ def _fwd_kernel(
if logit_cap > 0:
qk = logit_cap * tanh(qk / logit_cap)
final_mask = mask_m[:, None] & mask_n[None, :]
if USE_CUSTOM_MASK:
custom_mask = tl.load(
mask_ptr
......@@ -254,18 +261,30 @@ def _fwd_kernel(
other=0,
)
custom_mask &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(custom_mask, qk, float("-inf"))
final_mask &= custom_mask
elif IS_CAUSAL:
mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
start_n + offs_n[None, :]
)
mask_causual &= mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_causual, qk, float("-inf"))
final_mask &= mask_causual
else:
mask_non_causal = mask_m[:, None] & mask_n[None, :]
qk = tl.where(mask_non_causal, qk, float("-inf"))
final_mask &= mask_non_causal
if SLIDING_WINDOW_SIZE > 0:
# Add mask where q_id <= kv_id + sliding_window_size
window_mask = (cur_block_m * BLOCK_M + offs_m[:, None]) <= (
start_n + offs_n[None, :] + SLIDING_WINDOW_SIZE
)
final_mask &= window_mask
qk = tl.where(final_mask, qk, float("-inf"))
row_max = tl.max(qk, 1)
row_max_fixed = tl.where(row_max == float("-inf"), -1e20, row_max)
n_e_max = tl.maximum(row_max_fixed, e_max)
n_e_max = tl.maximum(tl.max(qk, 1), e_max)
re_scale = tl.exp(e_max - n_e_max)
p = tl.exp(qk - n_e_max[:, None])
deno = deno * re_scale + tl.sum(p, 1)
......@@ -283,6 +302,10 @@ def _fwd_kernel(
e_max = n_e_max
if HAS_SK:
cur_sk = tl.load(sk_ptr + cur_head)
deno += tl.exp(cur_sk - e_max)
offs_o = (
(cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None])
* stride_obs
......@@ -321,6 +344,7 @@ def extend_attention_fwd(
logit_cap=0.0,
skip_prefix_custom_mask=True,
sliding_window_size=-1,
sk=None,
):
"""
q_extend, k_extend, v_extend, o_extend: contiguous tensors
......@@ -386,6 +410,8 @@ def extend_attention_fwd(
# Skip custom mask for prefix part
SKIP_PREFIX_CUSTOM_MASK = skip_prefix_custom_mask
HAS_SK = sk is not None
grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M))
num_stages = 1
......@@ -405,6 +431,7 @@ def extend_attention_fwd(
kv_indices,
custom_mask,
mask_indptr,
sk,
sm_scale,
kv_group_num,
q_extend.stride(0),
......@@ -431,6 +458,7 @@ def extend_attention_fwd(
USE_CUSTOM_MASK=USE_CUSTOM_MASK,
IS_CAUSAL=is_causal,
SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
HAS_SK=HAS_SK,
STORE_TRANSPOSE=_is_hip,
num_warps=num_warps,
num_stages=num_stages,
......
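Two details of the extend-kernel change are easy to miss, so here is a dense PyTorch sketch of both (an aside, not part of the commit): (1) the window mask now offsets the query index by the prefix length, and (2) a fully masked row would make `tl.max(qk, 1)` return `-inf`, so the row max is clamped to a large negative value before the exponentials.

```python
import torch

def masked_softmax_row(qk, causal_mask, window_mask):
    """Mirror of the kernel's guarded softmax over one block of logits."""
    qk = torch.where(causal_mask & window_mask, qk, torch.tensor(float("-inf")))
    row_max = qk.max(dim=1).values
    # a fully masked row yields -inf; clamp it so exp() stays finite (kernel uses -1e20)
    row_max = torch.where(row_max == float("-inf"), torch.tensor(-1e20), row_max)
    p = torch.exp(qk - row_max[:, None])
    return p / p.sum(dim=1, keepdim=True).clamp_min(1e-20)

prefix_len, window = 3, 2
q_idx = prefix_len + torch.arange(2)[:, None]  # absolute query positions (prefix + m)
kv_idx = torch.arange(6)[None, :]              # absolute key positions
causal_mask = q_idx >= kv_idx
window_mask = q_idx <= kv_idx + window         # q attends to at most `window` older keys
qk = torch.randn(2, 6)
print(masked_softmax_row(qk, causal_mask, window_mask))
```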
......@@ -1191,11 +1191,6 @@ class RowParallelLinear(LinearBase):
else self.weight_loader
),
)
if not reduce_results and (bias and not skip_bias_add):
raise ValueError(
"When not reduce the results, adding bias to the "
"results can lead to incorrect results"
)
if bias:
self.bias = Parameter(torch.empty(self.output_size, dtype=params_dtype))
......
......@@ -134,6 +134,10 @@ class FusedMoE(torch.nn.Module):
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
enable_flashinfer_cutlass_moe: Optional[bool] = False,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
use_weight_loader_fused: bool = False,
with_bias=False,
):
super().__init__()
......@@ -148,6 +152,10 @@ class FusedMoE(torch.nn.Module):
self.expert_map_cpu = None
self.expert_map_gpu = None
# For activation
self.activation_alpha = activation_alpha
self.swiglu_limit = swiglu_limit
if enable_flashinfer_cutlass_moe and quant_config is None:
logger.warning("Disable flashinfer MoE when quantization config is None.")
enable_flashinfer_cutlass_moe = False
......@@ -191,7 +199,7 @@ class FusedMoE(torch.nn.Module):
if quant_config is None:
self.quant_method: Optional[QuantizeMethodBase] = UnquantizedFusedMoEMethod(
self.use_triton_kernels
self.use_triton_kernels, with_bias=with_bias
)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
......@@ -206,7 +214,12 @@ class FusedMoE(torch.nn.Module):
intermediate_size=self.intermediate_size_per_partition,
intermediate_size_per_partition=self.intermediate_size_per_partition,
params_dtype=params_dtype,
weight_loader=self.weight_loader,
weight_loader=(
self.weight_loader
if not use_weight_loader_fused
else self.weight_loader_fused
),
with_bias=with_bias,
)
def _load_per_tensor_weight_scale(
......@@ -234,6 +247,7 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
# Load grouped weight scales for group quantization
# or model weights
......@@ -244,14 +258,16 @@ class FusedMoE(torch.nn.Module):
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
elif shard_id in ("w1", "w3"):
elif shard_id in ("w1", "w3", "w13"):
self._load_w13(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
def _load_per_channel_weight_scale(
......@@ -281,17 +297,30 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
# Index the loaded weight for tp sharding.
# gate_up_proj: "MergedColumnParallel", so tp sharding on output_dim
shard_size = expert_data.shape[shard_dim] // 2
assert shard_id in {"w1", "w3", "w13"}
if is_bias:
# if this weight is a bias, the last dimension must be the sharded dimension
shard_dim = -1
if shard_id in {"w1", "w3"}:
# non-fused version
shard_size = expert_data.shape[shard_dim] // 2
elif shard_id in {"w13"}:
# fused version
shard_size = expert_data.shape[shard_dim]
else:
raise NotImplementedError
# Narrow parameter and load.
# w1, gate_proj: Load into first logical weight of w13.
# w3, up_proj: Load into second logical weight of w13.
# trtllm cutlass kernel assumes differently
assert shard_id in ("w1", "w3")
switch_w13 = getattr(self.quant_method, "load_up_proj_weight_first", False)
if (switch_w13 and shard_id == "w1") or (not switch_w13 and shard_id == "w3"):
start = shard_size
......@@ -310,7 +339,8 @@ class FusedMoE(torch.nn.Module):
)
else:
if not self.use_presharded_weights:
if self.use_triton_kernels:
if not is_bias and self.use_triton_kernels:
# do not transpose for bias
loaded_weight = loaded_weight.transpose(-2, -1)
loaded_weight = loaded_weight.narrow(
shard_dim, shard_size * tp_rank, shard_size
......@@ -326,6 +356,7 @@ class FusedMoE(torch.nn.Module):
shard_id: str,
loaded_weight: torch.Tensor,
tp_rank: int,
is_bias: bool = False,
):
"""Load w2 weights for down projection.
......@@ -356,7 +387,14 @@ class FusedMoE(torch.nn.Module):
# Index the loaded weight for tp sharding.
# down_proj: "RowParallel" so tp sharding on input_dim
# Narrow parameter and load.
shard_size = expert_data.shape[shard_dim]
if is_bias:
# this expert_data is a bias, not weight,
# for w2_bias in TP, it does not need to be sharded
shard_size = expert_data.shape[-1]
else:
# this parameter is a weight matrix
# for w2 in TP, it shards the input_features, i.e., shard_dim=2
shard_size = expert_data.shape[shard_dim]
if _is_cpu:
expert_data, loaded_weight = narrow_padded_param_and_loaded_weight(
......@@ -369,7 +407,7 @@ class FusedMoE(torch.nn.Module):
not self.use_presharded_weights,
)
else:
if not self.use_presharded_weights:
if not is_bias and not self.use_presharded_weights:
if self.use_triton_kernels:
loaded_weight = loaded_weight.transpose(-2, -1)
if shard_size * tp_rank + shard_size > loaded_weight.shape[shard_dim]:
......@@ -658,6 +696,68 @@ class FusedMoE(torch.nn.Module):
)
return
def weight_loader_fused(
self,
param: torch.nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
) -> None:
tp_rank = self.moe_tp_rank
# compressed-tensors checkpoints with packed weights are stored flipped
# TODO: check self.quant_method.quant_config.quant_format
# against known CompressionFormat enum values that have this quality
loaded_weight = (
loaded_weight.t().contiguous()
if (
self.quant_method.__class__.__name__
== "CompressedTensorsWNA16MoEMethod"
)
else loaded_weight
)
if shard_id not in ("w13", "w2"):
raise ValueError(f"shard_id must be ['w13','w2'] but " f"got {shard_id}.")
# Fetch the dim to shard the parameter/loaded weight
# based on the shard id. This will be whatever
# dimension intermediate_size is used.
SHARD_ID_TO_SHARDED_DIM = {"w13": 1, "w2": 2}
SHARD_ID_TO_SHARDED_DIM_TRANSPOSE = {"w13": 2, "w2": 1}
expert_data = param.data
is_bias = expert_data.dim() == 2
# is_transposed: if the dim to shard the weight
# should be flipped. Required by GPTQ, compressed-tensors
# should be whatever dimension intermediate_size is
is_transposed = getattr(param, "is_transposed", False)
if self.use_triton_kernels:
is_transposed = True
shard_dim = (
SHARD_ID_TO_SHARDED_DIM[shard_id]
if not is_transposed
else SHARD_ID_TO_SHARDED_DIM_TRANSPOSE[shard_id]
)
# Case model weights
if "weight" in weight_name:
self._load_model_weight_or_group_weight_scale(
shard_id=shard_id,
shard_dim=shard_dim,
loaded_weight=loaded_weight,
expert_data=expert_data,
tp_rank=tp_rank,
is_bias=is_bias,
)
return
else:
logging.warning(
f"Unsupported weight_name {weight_name} for FusedMoE weight_loader_fused. Nothing is loaded."
)
def forward(self, hidden_states: torch.Tensor, topk_output: StandardTopKOutput):
assert self.quant_method is not None
......@@ -673,6 +773,12 @@ class FusedMoE(torch.nn.Module):
# Matrix multiply.
with use_symmetric_memory(get_tp_group()) as sm:
kwargs = {}
if self.activation_alpha is not None:
kwargs["activation_alpha"] = self.activation_alpha
if self.swiglu_limit is not None:
kwargs["swiglu_limit"] = self.swiglu_limit
final_hidden_states = self.quant_method.apply(
layer=self,
x=hidden_states,
......@@ -691,6 +797,7 @@ class FusedMoE(torch.nn.Module):
== "ModelOptNvFp4FusedMoEMethod"
else {}
),
**kwargs,
)
sm.tag(final_hidden_states)
......@@ -728,6 +835,25 @@ class FusedMoE(torch.nn.Module):
]
]
@classmethod
def make_expert_params_mapping_fused(
cls,
ckpt_gate_up_proj_name: str,
ckpt_down_proj_name: str,
ckpt_gate_up_proj_bias_name: str,
ckpt_down_proj_bias_name: str,
):
return [
("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
(
"experts.w13_weight_bias",
f"experts.{ckpt_gate_up_proj_bias_name}",
"w13",
),
("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
]
@classmethod
def make_expert_input_scale_params_mapping(
cls,
......
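For orientation (not part of the diff): `make_expert_params_mapping_fused` only produces (param suffix, checkpoint suffix, shard id) triples that `load_weights` scans later. With the gpt-oss checkpoint names used further down in this PR, the mapping resolves as in this standalone reproduction:

```python
def make_expert_params_mapping_fused(
    ckpt_gate_up_proj_name, ckpt_down_proj_name,
    ckpt_gate_up_proj_bias_name, ckpt_down_proj_bias_name,
):
    # same triples as the classmethod above: (param suffix, checkpoint suffix, shard id)
    return [
        ("experts.w13_weight", f"experts.{ckpt_gate_up_proj_name}", "w13"),
        ("experts.w13_weight_bias", f"experts.{ckpt_gate_up_proj_bias_name}", "w13"),
        ("experts.w2_weight", f"experts.{ckpt_down_proj_name}", "w2"),
        ("experts.w2_weight_bias", f"experts.{ckpt_down_proj_bias_name}", "w2"),
    ]

for param, ckpt, shard in make_expert_params_mapping_fused(
    "gate_up_proj", "down_proj", "gate_up_proj_bias", "down_proj_bias"
):
    print(f"{ckpt:35s} -> {param:25s} (shard {shard})")
```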
......@@ -6,15 +6,50 @@ from typing import TYPE_CHECKING, Optional
import torch
from sgl_kernel import gelu_and_mul, silu_and_mul
from triton_kernels.matmul_ogs import matmul_ogs
from triton_kernels.matmul_ogs import (
FlexCtx,
FnSpecs,
FusedActivation,
PrecisionConfig,
matmul_ogs,
)
from triton_kernels.numerics import InFlexData
from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
from sglang.srt.utils import direct_register_custom_op
from triton_kernels.swiglu import swiglu_fn
if TYPE_CHECKING:
from sglang.srt.layers.moe.topk import TopKOutput
def quantize(w, dtype, dev, **opt):
if dtype == "bf16":
return w.to(torch.bfloat16), InFlexData()
elif dtype == "fp8":
wq = w.to(torch.float8_e4m3fn).transpose(-1, -2).contiguous().transpose(-1, -2)
return (
wq,
InFlexData(dtype=wq.dtype, scale=w.abs().max().unsqueeze(0)),
MicroscalingCtx(),
)
else:
assert dtype == "mx4", f"{dtype=}"
swizzle_mx_scale = opt["swizzle_mx_scale"]
swizzle_axis = 2 if swizzle_mx_scale else None
w = w.to(torch.bfloat16)
w, mx_scales, weight_scale_shape = downcast_to_mxfp(
w, torch.uint8, axis=1, swizzle_axis=swizzle_axis
)
return (
w,
InFlexData(),
MicroscalingCtx(
weight_scale=mx_scales,
swizzle_mx=swizzle_mx_scale,
actual_weight_scale_shape=weight_scale_shape,
),
)
def triton_kernel_moe_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
......@@ -146,3 +181,143 @@ def triton_kernel_fused_experts(
)
return intermediate_cache3
def triton_kernel_moe_with_bias_forward(
hidden_states: torch.Tensor,
w1: torch.Tensor,
b1: torch.Tensor,
w2: torch.Tensor,
b2: torch.Tensor,
topk_output: TopKOutput,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
assert topk_output.format.is_triton_kernel()
routing_data, gather_idx, scatter_idx = topk_output
return triton_kernel_fused_experts_with_bias(
hidden_states,
w1,
b1,
w2,
b2,
routing_data,
gather_idx,
scatter_idx,
inplace=inplace,
activation=activation,
use_fp8_w8a8=use_fp8_w8a8,
per_channel_quant=per_channel_quant,
global_num_experts=global_num_experts,
expert_map=expert_map,
w1_scale=w1_scale,
w2_scale=w2_scale,
a1_scale=a1_scale,
a2_scale=a2_scale,
block_shape=block_shape,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
def triton_kernel_fused_experts_with_bias(
hidden_states: torch.Tensor,
w1: torch.Tensor,
b1: torch.Tensor,
w2: torch.Tensor,
b2: torch.Tensor,
routing_data: RoutingData,
gather_indx: GatherIndx,
scatter_indx: ScatterIndx,
inplace: bool = False,
activation: str = "silu",
use_fp8_w8a8: bool = False,
per_channel_quant: bool = False,
global_num_experts: int = -1,
expert_map: Optional[torch.Tensor] = None,
w1_scale: Optional[torch.Tensor] = None,
w2_scale: Optional[torch.Tensor] = None,
a1_scale: Optional[torch.Tensor] = None,
a2_scale: Optional[torch.Tensor] = None,
block_shape: Optional[list[int]] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[int] = None,
) -> torch.Tensor:
# print(f"here in triton moe with bias", b1.shape, b1.dtype, b2.shape, b2.dtype)
assert use_fp8_w8a8 == False, "use_fp8_w8a8 is not supported"
assert per_channel_quant == False, "per_channel_quant is not supported"
assert expert_map == None, "expert_map is not supported"
assert w1_scale == None, "w1_scale is not supported"
assert w2_scale == None, "w2_scale is not supported"
assert a1_scale == None, "a1_scale is not supported"
assert a2_scale == None, "a2_scale is not supported"
assert block_shape == None, "block_shape is not supported"
# type check
assert hidden_states.dtype == torch.bfloat16, "hidden_states must be bfloat16"
assert w1.dtype == torch.bfloat16, "w1 must be bfloat16"
assert w2.dtype == torch.bfloat16, "w2 must be bfloat16"
# Shape check
assert hidden_states.ndim == 2, "hidden_states must be 2D"
assert (
hidden_states.shape[-1] == w1.shape[-2]
), f"hidden_states shape[-1] {hidden_states.shape} must be equal to w1 shape[-2] {w1.shape}"
assert (
w2.shape[-1] == w1.shape[1]
), f"w2 shape[-1] {w2.shape[-1]} must be equal to w1 shape[1] {w1.shape[1]}"
# feature check
assert inplace == False, "Inplace is not supported in new triton MoE kernel"
E, _, _ = w1.shape
if global_num_experts == -1:
global_num_experts = E
device = "cuda"
optg = dict()
w1, w1_flex = quantize(w1, "bf16", device, **optg)
w1_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w1_flex))
w2, w2_flex = quantize(w2, "bf16", device, **optg)
w2_pcg = PrecisionConfig(flex_ctx=FlexCtx(rhs_data=w2_flex))
act = FusedActivation(
FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")),
(activation_alpha, swiglu_limit),
2,
)
intermediate_cache = matmul_ogs(
hidden_states,
w1,
b1,
routing_data,
gather_indx=gather_indx,
precision_config=w1_pcg,
gammas=None,
fused_activation=act,
)
return matmul_ogs(
intermediate_cache,
w2,
b2,
routing_data,
scatter_indx=scatter_indx,
precision_config=w2_pcg,
gammas=routing_data.gate_scal,
)
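Conceptually, `triton_kernel_fused_experts_with_bias` is two `matmul_ogs` calls: a fused gate/up GEMM with bias plus the swiglu activation, then a down GEMM with bias scaled by the routing gates (`gammas=routing_data.gate_scal`). A loop-over-tokens reference of that dataflow follows (an aside, not the Triton path; the SiLU-gated split here is a plain stand-in for the fused `swiglu_fn(alpha, limit)`, and the split convention is illustrative only). Shapes follow the asserts above: `w1: [E, H, 2I]`, `w2: [E, I, H]`.

```python
import torch
import torch.nn.functional as F

def moe_with_bias_reference(x, w1, b1, w2, b2, topk_ids, topk_weights):
    """Per-token reference for the two matmul_ogs stages (x: [T, H])."""
    out = torch.zeros_like(x)
    for t in range(x.shape[0]):
        for e, g in zip(topk_ids[t].tolist(), topk_weights[t].tolist()):
            h = x[t] @ w1[e] + b1[e]             # stage 1: fused gate/up projection + bias
            gate, up = h.chunk(2, dim=-1)        # split convention is illustrative only
            act = F.silu(gate) * up              # stand-in for swiglu_fn(alpha, limit)
            out[t] += g * (act @ w2[e] + b2[e])  # stage 2: down projection + bias, gated
    return out

T, H, D_INTER, E, K = 4, 8, 16, 3, 2
x = torch.randn(T, H)
w1, b1 = torch.randn(E, H, 2 * D_INTER), torch.zeros(E, 2 * D_INTER)
w2, b2 = torch.randn(E, D_INTER, H), torch.zeros(E, H)
topk_ids = torch.randint(0, E, (T, K))
topk_weights = torch.softmax(torch.randn(T, K), dim=-1)
print(moe_with_bias_reference(x, w1, b1, w2, b2, topk_ids, topk_weights).shape)  # [4, 8]
```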
......@@ -4,6 +4,7 @@ import torch
from sglang.srt.layers.quantization import deep_gemm_wrapper
from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
from sglang.srt.layers.utils import is_sm100_supported
try:
......@@ -26,6 +27,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
)
from sglang.srt.utils import (
align,
ceil_div,
get_bool_env_var,
get_cuda_version,
get_device_capability,
......@@ -307,6 +309,33 @@ def triton_w8a8_block_fp8_linear(
return output.to(dtype=input_2d.dtype).view(*output_shape)
def dequant_mxfp4(
w_block: torch.Tensor,
w_scale: torch.Tensor,
out_dtype,
) -> torch.Tensor:
"""
:param w_block: (batch, n, k, 16), uint8, two mxfp4 values packed per byte
:param w_scale: (batch, n, k), uint8 (E8M0 block-scale exponents)
:return: (batch, n, k * 32), out_dtype
"""
assert w_block.dtype == torch.uint8
assert w_scale.dtype == torch.uint8
batch, n, k, pack_dim = w_block.shape
batch_, n_, k_ = w_scale.shape
assert pack_dim == 16
assert batch == batch_
assert n == n_
assert k == k_
out_raw = MXFP4QuantizeUtil.dequantize(
quantized_data=w_block, scale=w_scale, dtype=out_dtype, block_sizes=[32]
)
return out_raw.reshape(batch, n, k * 32)
def input_to_float8(
x: torch.Tensor, dtype: torch.dtype = fp8_dtype
) -> Tuple[torch.Tensor, torch.Tensor]:
......
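A quick shape check for `dequant_mxfp4` above (not part of the diff): each uint8 in the packed dimension holds two E2M1 values and each scale covers a 32-value block, so a `(batch, n, k, 16)` block tensor dequantizes to `(batch, n, k * 32)`. The import path matches the one used by the model code in this PR.

```python
import torch
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4

# zero codes with scale exponent 127 (i.e. 2**0) decode to exact zeros
w_block = torch.zeros(1, 4, 3, 16, dtype=torch.uint8)    # (batch, n, k, 16), 2 fp4 per byte
w_scale = torch.full((1, 4, 3), 127, dtype=torch.uint8)  # (batch, n, k), E8M0 exponents
w = dequant_mxfp4(w_block=w_block, w_scale=w_scale, out_dtype=torch.bfloat16)
print(w.shape, w.abs().max().item())  # torch.Size([1, 4, 96]) 0.0
```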
# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
# https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/quantization/qtensor/mxfp4_tensor.py
class MXFP4QuantizeUtil:
E2M1_max = 6.0
E2M1_values = [0, 0.5, 1, 1.5, 2, 3, 4, 6]
E2M1_bounds = torch.tensor([0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5])
@classmethod
def quantize(cls, input: torch.Tensor, block_size: int | None) -> tuple:
"""Converting a tensor to a quantized format based on MXFP4 quantization. Only E4M3 is supported.
Args:
input (torch.Tensor): The input tensor to be quantized.
block_sizes (dict | None): The block sizes for quantization.
"""
def cast_fp4(x):
sign = torch.sign(x)
sign_bit = (2 - sign) // 2
ord_ = torch.sum(
(x.abs().unsqueeze(-1) - cls.E2M1_bounds.to(x.device)) > 0, dim=-1
)
fp4_val = (sign_bit * 0b1000 + ord_).to(torch.uint8)
return fp4_val
def fuse_uint4_to_uint8(x):
# If the last dimension is odd, pad with zeros
# If this behavior is not desired, please modify the code accordingly
left_side = x[..., 0::2] # Even indices (0, 2, 4...)
right_side = x[..., 1::2] # Odd indices (1, 3, 5...)
new_data = (
right_side.clone() << 4
) # Put odd indices (higher addresses) in high bits
new_data[
..., : left_side.shape[-1]
] += left_side # Put even indices in low bits
return new_data
if block_size is None:
block_size = 32
original_shape = input.shape
original_dtype = input.dtype
input = input.view(-1, block_size)
# get scales
input_amax = input.abs().max(dim=-1, keepdim=True).values
descale = input_amax / cls.E2M1_max
min_value = torch.tensor(-127.0, device=descale.device)
e8m0_scale = torch.ceil(torch.maximum(torch.log2(descale), min_value))
input = (input / torch.exp2(e8m0_scale)).view(original_shape)
input_q = cast_fp4(input)
input_q = fuse_uint4_to_uint8(input_q)
e8m0_scale = (e8m0_scale + 127).to(torch.uint8)
return cls(original_shape, original_dtype, input_q), e8m0_scale
@classmethod
def dequantize(cls, quantized_data, dtype: torch.dtype, scale, block_sizes):
"""Dequantze MXFP4 packed tensor to a target dtype."""
def unfuse_uint8_to_uint4(x):
"""Unfuse uint8 values back to uint4 values.
This is the inverse operation of fuse_uint4_to_uint8.
"""
# Extract the lower 4 bits (even indices)
left_side = x & 0x0F
# Extract the upper 4 bits (odd indices)
right_side = (x >> 4) & 0x0F
# Create a new tensor with alternating values
shape = list(x.shape)
shape[-1] = shape[-1] * 2
result = torch.zeros(shape, dtype=torch.uint8, device=x.device)
# Fill in the values - even indices get low bits, odd indices get high bits
result[..., 0::2] = left_side # Even indices from low bits
result[..., 1::2] = right_side # Odd indices from high bits
return result
e8m0_scale = scale
block_size = block_sizes[-1]
# Unfuse the uint8 values back to uint4
x_unfused = unfuse_uint8_to_uint4(quantized_data)
# Extract sign and magnitude
sign = 1 - 2 * ((x_unfused & 0b1000) >> 3).to(
torch.float32
) # Extract sign bit and convert to +1/-1
magnitude = x_unfused & 0b0111 # Extract magnitude bits
magnitude = magnitude.to(torch.long)
# Create a tensor with the E2M1 values
values = torch.tensor(cls.E2M1_values, device=quantized_data.device)
# Use gather to index the values tensor properly
# We need to reshape magnitude to match the dimensions we want to gather along
original_shape = magnitude.shape
x_float = values[magnitude.reshape(-1)].reshape(original_shape)
# Apply sign and scale
x_float = sign.float() * x_float
# Reshape to apply block-wise scaling
x_float = x_float.reshape(-1, block_size)
# Apply the E8M0 scale
scale_factor = torch.exp2(e8m0_scale.float() - 127)
scale_factor = scale_factor.reshape(-1, 1) # Reshape for proper broadcasting
# Apply scaling and reshape back to original shape
x_float = x_float * scale_factor
# Reshape back to the original shape
return x_float.reshape(original_shape).to(dtype)
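To make the packing convention concrete (an aside, not part of the commit): the low nibble of each byte holds the even-indexed E2M1 code and the high nibble the odd-indexed one, with each code being `sign_bit * 8 + ordinal` into the `E2M1_values` table. A tiny hand-packed round trip through `dequantize`, with byte values worked out from that table:

```python
import torch
from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil

# E2M1 codes: 0b0010 -> +1.0, 0b0111 -> +6.0, 0b1001 -> -0.5 (sign bit is 0b1000)
# byte layout: low nibble = even index, high nibble = odd index
packed = torch.tensor(
    [[(0b0111 << 4) | 0b0010,    # decodes to [+1.0, +6.0]
      (0b0000 << 4) | 0b1001]],  # decodes to [-0.5,  0.0]
    dtype=torch.uint8,
)
scale = torch.tensor([128], dtype=torch.uint8)  # E8M0 exponent 128 -> block scale 2**1
out = MXFP4QuantizeUtil.dequantize(
    quantized_data=packed, dtype=torch.float32, scale=scale, block_sizes=[4]
)
print(out.tolist())  # [[2.0, 12.0, -1.0, 0.0]]
```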
......@@ -126,17 +126,23 @@ class UnquantizedLinearMethod(LinearMethodBase):
class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
"""MoE method without quantization."""
def __init__(self, use_triton_kernels: bool = False):
def __init__(self, use_triton_kernels: bool = False, with_bias: bool = False):
super().__init__()
self.use_triton_kernels = use_triton_kernels
self.with_bias = with_bias
self.triton_kernel_moe_forward = None
self.triton_kernel_moe_with_bias_forward = None
if torch.cuda.is_available() and has_triton_kernels:
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_forward as _tk_forward,
)
from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
triton_kernel_moe_with_bias_forward as _tk_with_bias_forward,
)
self.triton_kernel_moe_forward = _tk_forward
self.triton_kernel_moe_with_bias_forward = _tk_with_bias_forward
def create_weights(
self,
......@@ -158,6 +164,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
if self.with_bias:
w13_weight_bias = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_bias", w13_weight_bias)
set_weight_attrs(w13_weight_bias, extra_weight_attrs)
# down_proj (row parallel)
w2_weight_n, w2_weight_k = (
hidden_size,
......@@ -172,6 +186,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
if self.with_bias:
w2_weight_bias = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_bias", w2_weight_bias)
set_weight_attrs(w2_weight_bias, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
......@@ -202,7 +224,14 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
kwargs = {}
if activation_alpha is not None:
kwargs["activation_alpha"] = activation_alpha
if swiglu_limit is not None:
kwargs["swiglu_limit"] = swiglu_limit
return self.forward(
x=x,
......@@ -213,6 +242,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace=inplace,
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
**kwargs,
)
def forward_cuda(
......@@ -226,15 +256,30 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
inplace: bool = True,
no_combine: bool = False,
routed_scaling_factor: Optional[float] = None,
activation_alpha: Optional[float] = None,
swiglu_limit: Optional[float] = None,
) -> torch.Tensor:
if self.use_triton_kernels:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
)
if self.with_bias:
return self.triton_kernel_moe_with_bias_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
b1=layer.w13_weight_bias,
b2=layer.w2_weight_bias,
topk_output=topk_output,
activation=activation,
activation_alpha=activation_alpha,
swiglu_limit=swiglu_limit,
)
else:
return self.triton_kernel_moe_forward(
hidden_states=x,
w1=layer.w13_weight,
w2=layer.w2_weight,
topk_output=topk_output,
)
else:
if _use_aiter:
assert not no_combine, "unsupported"
......
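A small note on the bias plumbing above (not part of the diff): with `with_bias=True`, `create_weights` registers `w13_weight_bias` of shape `[num_experts, 2 * intermediate_size]` and `w2_weight_bias` of shape `[num_experts, hidden_size]` (float32, no grad), and `apply` forwards `activation_alpha` / `swiglu_limit` only when they are set. A minimal sketch of both pieces, outside SGLang's classes:

```python
import torch

num_experts, hidden_size, intermediate_size = 4, 8, 16

# bias shapes mirroring the create_weights additions above (float32, no grad)
w13_weight_bias = torch.nn.Parameter(
    torch.empty(num_experts, 2 * intermediate_size, dtype=torch.float32),
    requires_grad=False,
)
w2_weight_bias = torch.nn.Parameter(
    torch.empty(num_experts, hidden_size, dtype=torch.float32),
    requires_grad=False,
)

# apply()/forward_cuda() only forward the optional activation kwargs when set
def collect_activation_kwargs(activation_alpha=None, swiglu_limit=None):
    kwargs = {}
    if activation_alpha is not None:
        kwargs["activation_alpha"] = activation_alpha
    if swiglu_limit is not None:
        kwargs["swiglu_limit"] = swiglu_limit
    return kwargs

print(w13_weight_bias.shape, w2_weight_bias.shape)
print(collect_activation_kwargs(activation_alpha=1.702))  # alpha default from the model code
```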
......@@ -917,8 +917,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
is_hybrid = False
if isinstance(token_to_kv_pool_allocator, SWATokenToKVPoolAllocator):
assert isinstance(tree_cache, SWARadixCache) or isinstance(
tree_cache, SWAChunkCache
assert (
tree_cache is None
or isinstance(tree_cache, SWARadixCache)
or isinstance(tree_cache, SWAChunkCache)
), "SWARadixCache or SWAChunkCache is required for SWATokenToKVPoolAllocator"
is_hybrid = True
......
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Inference-only GptOss model compatible with HuggingFace weights."""
import logging
from collections.abc import Iterable
from functools import partial
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers import PretrainedConfig
from sglang.srt.distributed import (
get_moe_tensor_parallel_rank,
get_pp_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
from sglang.srt.layers.dp_attention import (
get_attention_tp_rank,
get_attention_tp_size,
get_local_attention_dp_size,
)
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
QKVParallelLinear,
ReplicatedLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.topk import TopK
from sglang.srt.layers.moe.utils import DeepEPMode
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.fp8_utils import dequant_mxfp4
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.rotary_embedding import get_rope
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from sglang.srt.managers.schedule_batch import global_server_args_dict
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.utils import add_prefix, make_layers
class GptOssConfig(PretrainedConfig):
model_type = "gpt_oss"
def __init__(self, **kwargs):
super().__init__(**kwargs)
logger = logging.getLogger(__name__)
# HF's implementation treats the sliding window as inclusive of the current token,
# while SGLang assumes it is exclusive, hence the -1 below.
def get_attention_sliding_window_size(config):
return config.sliding_window - 1
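A concrete reading of the off-by-one above (an aside, not part of the commit): a hypothetical `sliding_window = 128` in the HF config becomes 127 in SGLang, i.e. each query attends to itself plus the 127 previous tokens.

```python
def get_attention_sliding_window_size(config):
    # same one-liner as above: HF's window includes the current token, SGLang's does not
    return config.sliding_window - 1

class _Cfg:  # hypothetical config stand-in
    sliding_window = 128

assert get_attention_sliding_window_size(_Cfg()) == 127
```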
class GptOssSparseMoeBlock(nn.Module):
def __init__(
self,
layer_id: int,
config: GptOssConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
):
super().__init__()
self.tp_size = get_tensor_model_parallel_world_size()
self.layer_id = layer_id
self.activation = config.hidden_act
self.activation_alpha = getattr(config, "hidden_act_alpha", 1.702)
self.swiglu_limit = config.swiglu_limit
if self.tp_size > config.num_local_experts:
raise ValueError(
f"Tensor parallel size {self.tp_size} is greater than "
f"the number of experts {config.num_local_experts}."
)
self.topk = TopK(
top_k=config.num_experts_per_tok,
renormalize=True,
)
experts_type = get_moe_impl_class()
extra_kwargs = {}
if experts_type.__name__ == "FusedMoE":
extra_kwargs = {
"enable_flashinfer_cutlass_moe": global_server_args_dict[
"enable_flashinfer_cutlass_moe"
],
"use_weight_loader_fused": True, # for moe gate_up_proj and down_proj and their bias loading
}
self.experts = experts_type(
num_experts=config.num_local_experts
+ global_server_args_dict["ep_num_redundant_experts"],
top_k=config.num_experts_per_tok,
layer_id=layer_id,
hidden_size=config.hidden_size,
intermediate_size=config.intermediate_size,
quant_config=quant_config,
activation=self.activation,
activation_alpha=self.activation_alpha,
swiglu_limit=self.swiglu_limit,
with_bias=True,
prefix=add_prefix("experts", prefix),
**(
dict(deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]])
if global_server_args_dict["moe_a2a_backend"].is_deepep()
else {}
),
**extra_kwargs,
)
self.router = ReplicatedLinear(
config.hidden_size,
config.num_local_experts,
bias=True,
quant_config=None,
prefix=add_prefix("gate", prefix),
params_dtype=config.torch_dtype,
)
def forward(
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
) -> torch.Tensor:
if not global_server_args_dict["moe_a2a_backend"].is_deepep():
return self.forward_normal(hidden_states)
else:
raise Exception("forward_deepep branch not implemented yet")
def get_moe_weights(self):
return [
x.data
for name, x in self.experts.named_parameters()
if name not in ["correction_bias"]
]
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
num_tokens, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
# router_logits: (num_tokens, n_experts)
router_logits, _ = self.router(hidden_states)
kwargs = {"hidden_states": hidden_states}
if self.topk is not None:
kwargs["topk_output"] = self.topk(hidden_states, router_logits)
else:
kwargs["router_logits"] = router_logits
final_hidden_states = self.experts(**kwargs)
if self.tp_size > 1:
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
ans = final_hidden_states.view(num_tokens, hidden_dim)
return ans
class GptOssAttention(nn.Module):
def __init__(
self,
hidden_size: int,
num_heads: int,
num_kv_heads: int,
layer_id: int = 0,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
head_dim: Optional[int] = None,
rms_norm_eps: float = 1e-06,
attention_bias: bool = False,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
sliding_window_size: int = -1, # if -1, normal attention, else, window attention.
layer_type: str = "",
params_dtype: torch.dtype = torch.bfloat16,
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.sliding_window_size = sliding_window_size
attn_tp_rank = get_attention_tp_rank()
attn_tp_size = get_attention_tp_size()
self.total_num_heads = num_heads
assert self.total_num_heads % attn_tp_size == 0
self.num_heads = self.total_num_heads // attn_tp_size
self.total_num_kv_heads = num_kv_heads
if self.total_num_kv_heads >= attn_tp_size:
# Number of KV heads is greater than TP size, so we partition
# the KV heads across multiple tensor parallel GPUs.
assert self.total_num_kv_heads % attn_tp_size == 0
else:
# Number of KV heads is less than TP size, so we replicate
# the KV heads across multiple tensor parallel GPUs.
assert attn_tp_size % self.total_num_kv_heads == 0
self.num_kv_heads = max(1, self.total_num_kv_heads // attn_tp_size)
self.head_dim = head_dim or hidden_size // self.total_num_heads
self.q_size = self.num_heads * self.head_dim
self.kv_size = self.num_kv_heads * self.head_dim
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
self.tp_rank = get_tensor_model_parallel_rank()
self.qkv_proj = QKVParallelLinear(
hidden_size,
self.head_dim,
self.total_num_heads,
self.total_num_kv_heads,
bias=attention_bias,
params_dtype=params_dtype,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
prefix=add_prefix("qkv_proj", prefix),
)
self.sinks = nn.Parameter(
torch.empty(self.num_heads, dtype=params_dtype), requires_grad=False
)
self.o_proj = RowParallelLinear(
self.total_num_heads * self.head_dim,
hidden_size,
bias=attention_bias,
quant_config=quant_config,
tp_rank=attn_tp_rank,
tp_size=attn_tp_size,
reduce_results=False,
params_dtype=params_dtype,
prefix=add_prefix("o_proj", prefix),
)
self.rotary_emb = get_rope(
self.head_dim,
rotary_dim=self.head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
)
assert layer_type in {"sliding_attention", "full_attention"}
use_sliding_window = layer_type == "sliding_attention"
self.attn = RadixAttention(
self.num_heads,
self.head_dim,
self.scaling,
num_kv_heads=self.num_kv_heads,
layer_id=layer_id,
prefix=add_prefix("attn", prefix),
sliding_window_size=(sliding_window_size if use_sliding_window else -1),
)
self.layer_id = layer_id
def forward_prepare(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
):
if hidden_states.shape[0] == 0:
return hidden_states, forward_batch, None
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
q, k = self.rotary_emb(positions, q, k)
inner_state = q, k, v, forward_batch
return None, forward_batch, inner_state
def forward_core(self, intermediate_state):
hidden_states, forward_batch, inner_state = intermediate_state
if inner_state is None:
return hidden_states
attn_output = self.attn(*inner_state, sk=self.sinks)
output, _ = self.o_proj(attn_output)
return output
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
) -> torch.Tensor:
s = self.forward_prepare(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
return self.forward_core(s)
class GptOssDecoderLayer(nn.Module):
def __init__(
self,
config: GptOssConfig,
layer_id: int,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
sliding_window_size: int | None = None,
) -> None:
super().__init__()
self.config = config
self.hidden_size = config.hidden_size
rope_theta = getattr(config, "rope_theta", 10000)
rope_scaling = getattr(config, "rope_scaling", None)
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
head_dim = getattr(
config, "head_dim", config.hidden_size // config.num_attention_heads
)
rms_norm_eps = config.rms_norm_eps
attention_bias = config.attention_bias
if sliding_window_size is None:
self.sliding_window_size = get_attention_sliding_window_size(self.config)
else:
self.sliding_window_size = sliding_window_size
self.self_attn = GptOssAttention(
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
num_kv_heads=config.num_key_value_heads,
layer_id=layer_id,
rope_theta=rope_theta,
rope_scaling=rope_scaling,
max_position_embeddings=max_position_embeddings,
head_dim=head_dim,
rms_norm_eps=rms_norm_eps,
attention_bias=attention_bias,
quant_config=quant_config,
prefix=add_prefix("self_attn", prefix),
sliding_window_size=self.sliding_window_size,
layer_type=config.layer_types[layer_id],
params_dtype=config.torch_dtype,
)
self.layer_id = layer_id
self.attn_tp_size = get_attention_tp_size()
self.attn_tp_rank = get_attention_tp_rank()
self.local_dp_size = get_local_attention_dp_size()
# All GptOss layers are sparse for now, and there is no nextn layer
self.is_layer_sparse = True
is_previous_layer_sparse = True
self.layer_scatter_modes = LayerScatterModes.init_new(
layer_id=layer_id,
num_layers=config.num_hidden_layers,
is_layer_sparse=self.is_layer_sparse,
is_previous_layer_sparse=is_previous_layer_sparse,
)
if self.is_layer_sparse:
self.mlp = GptOssSparseMoeBlock(
layer_id=self.layer_id,
config=config,
quant_config=quant_config,
prefix=add_prefix("mlp", prefix),
)
else:
raise NotImplementedError(
"Dense MLP is not implemented for GptOssDecoderLayer. "
"Please use GptOssSparseMoeBlock instead."
)
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
self.post_attention_layernorm = RMSNorm(
config.hidden_size, eps=config.rms_norm_eps
)
self.layer_communicator = LayerCommunicator(
layer_scatter_modes=self.layer_scatter_modes,
input_layernorm=self.input_layernorm,
post_attention_layernorm=self.post_attention_layernorm,
)
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
forward_batch: ForwardBatch,
residual: Optional[torch.Tensor],
) -> Tuple[torch.Tensor, torch.Tensor]:
hidden_states, residual = self.layer_communicator.prepare_attn(
hidden_states, residual, forward_batch
)
if hidden_states.shape[0] != 0:
hidden_states = self.self_attn(
positions=positions,
hidden_states=hidden_states,
forward_batch=forward_batch,
)
hidden_states, residual = self.layer_communicator.prepare_mlp(
hidden_states, residual, forward_batch
)
hidden_states = self.mlp(hidden_states, forward_batch)
hidden_states, residual = self.layer_communicator.postprocess_layer(
hidden_states, residual, forward_batch
)
return hidden_states, residual
class GptOssModel(nn.Module):
def __init__(
self,
config: PretrainedConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
decoder_layer_type: type[nn.Module] = GptOssDecoderLayer,
) -> None:
super().__init__()
self.padding_idx = config.pad_token_id
self.vocab_size = config.vocab_size
self.pp_group = get_pp_group()
if self.pp_group.is_first_rank:
self.embed_tokens = VocabParallelEmbedding(
config.vocab_size,
config.hidden_size,
enable_tp=not global_server_args_dict["enable_dp_attention"],
prefix=add_prefix("embed_tokens", prefix),
)
else:
self.embed_tokens = PPMissingLayer()
# Use the provided decoder layer type or default to GptOssDecoderLayer
decoder_layer_type = decoder_layer_type or GptOssDecoderLayer
self.layers, self.start_layer, self.end_layer = make_layers(
config.num_hidden_layers,
lambda idx, prefix: decoder_layer_type(
layer_id=idx,
config=config,
quant_config=quant_config,
prefix=prefix,
),
pp_rank=self.pp_group.rank_in_group,
pp_size=self.pp_group.world_size,
prefix=add_prefix("layers", prefix),
)
if self.pp_group.is_last_rank:
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
else:
self.norm = PPMissingLayer(return_tuple=True)
self.layers_to_capture = []
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> Union[torch.Tensor, PPProxyTensors]:
if self.pp_group.is_first_rank:
if input_embeds is None:
hidden_states = self.embed_tokens(input_ids)
else:
hidden_states = input_embeds
residual = None
else:
assert pp_proxy_tensors is not None
hidden_states = pp_proxy_tensors["hidden_states"]
residual = pp_proxy_tensors["residual"]
aux_hidden_states = []
for i in range(self.start_layer, self.end_layer):
with get_global_expert_distribution_recorder().with_current_layer(i):
if i in self.layers_to_capture:
aux_hidden_states.append(hidden_states + residual)
layer = self.layers[i]
hidden_states, residual = layer(
positions, hidden_states, forward_batch, residual
)
if not self.pp_group.is_last_rank:
return PPProxyTensors(
{
"hidden_states": hidden_states,
"residual": residual,
}
)
else:
if hidden_states.shape[0] != 0:
if residual is None:
hidden_states = self.norm(hidden_states)
else:
hidden_states, _ = self.norm(hidden_states, residual)
if len(aux_hidden_states) == 0:
return hidden_states
return hidden_states, aux_hidden_states
class GptOssForCausalLM(nn.Module):
fall_back_to_pt_during_load = False
def __init__(
self,
config: GptOssConfig,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.pp_group = get_pp_group()
self.config = config
self.quant_config = quant_config
self.model = GptOssModel(
config, quant_config, prefix=add_prefix("model", prefix)
)
self.lm_head = ParallelLMHead(
config.vocab_size,
config.hidden_size,
quant_config=quant_config,
prefix=add_prefix("lm_head", prefix),
use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
)
self.logits_processor = LogitsProcessor(config)
self.capture_aux_hidden_states = False
@torch.no_grad()
def forward(
self,
input_ids: torch.Tensor,
positions: torch.Tensor,
forward_batch: ForwardBatch,
input_embeds: torch.Tensor = None,
pp_proxy_tensors: Optional[PPProxyTensors] = None,
) -> torch.Tensor:
hidden_states = self.model(
input_ids,
positions,
forward_batch,
input_embeds,
pp_proxy_tensors=pp_proxy_tensors,
)
aux_hidden_states = None
if self.capture_aux_hidden_states:
hidden_states, aux_hidden_states = hidden_states
if self.pp_group.is_last_rank:
return self.logits_processor(
input_ids,
hidden_states,
self.lm_head,
forward_batch,
aux_hidden_states,
)
else:
return hidden_states
@property
def start_layer(self):
return self.model.start_layer
@property
def end_layer(self):
return self.model.end_layer
def _get_default_weight_mapping(self):
"""Generate default weight name mapping for GptOss safetensors."""
weight_mapping = {}
# Map router weights to gate
weight_mapping["embedding.weight"] = "model.embed_tokens.weight"
weight_mapping["unembedding.weight"] = "lm_head.weight"
weight_mapping["norm.scale"] = "model.norm.weight"
for layer_id in range(self.config.num_hidden_layers):
weight_mapping[f"block.{layer_id}.attn.q_proj.weight"] = (
f"model.layers.{layer_id}.self_attn.q_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.q_proj.bias"] = (
f"model.layers.{layer_id}.self_attn.q_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.k_proj.weight"] = (
f"model.layers.{layer_id}.self_attn.k_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.k_proj.bias"] = (
f"model.layers.{layer_id}.self_attn.k_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.v_proj.weight"] = (
f"model.layers.{layer_id}.self_attn.v_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.v_proj.bias"] = (
f"model.layers.{layer_id}.self_attn.v_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.out.weight"] = (
f"model.layers.{layer_id}.self_attn.o_proj.weight"
)
weight_mapping[f"block.{layer_id}.attn.out.bias"] = (
f"model.layers.{layer_id}.self_attn.o_proj.bias"
)
weight_mapping[f"block.{layer_id}.attn.sinks"] = (
f"model.layers.{layer_id}.self_attn.sinks"
)
weight_mapping[f"block.{layer_id}.attn.norm.scale"] = (
f"model.layers.{layer_id}.input_layernorm.weight"
)
weight_mapping[f"block.{layer_id}.mlp.gate.weight"] = (
f"model.layers.{layer_id}.mlp.router.weight"
)
weight_mapping[f"block.{layer_id}.mlp.gate.bias"] = (
f"model.layers.{layer_id}.mlp.router.bias"
)
weight_mapping[f"block.{layer_id}.mlp.norm.scale"] = (
f"model.layers.{layer_id}.post_attention_layernorm.weight"
)
weight_mapping[f"block.{layer_id}.mlp.experts.gate_up_proj"] = (
f"model.layers.{layer_id}.mlp.experts.gate_up_proj"
)
weight_mapping[f"block.{layer_id}.mlp.gate_up_proj_bias"] = (
f"model.layers.{layer_id}.mlp.experts.gate_up_proj_bias"
)
weight_mapping[f"block.{layer_id}.mlp.down_proj"] = (
f"model.layers.{layer_id}.mlp.experts.mlp2_weight"
)
weight_mapping[f"block.{layer_id}.mlp.down_proj_bias"] = (
f"model.layers.{layer_id}.mlp.experts.mlp2_bias"
)
return weight_mapping
def load_weights(
self,
weights: Iterable[Tuple[str, torch.Tensor]],
is_nextn: bool = False,
weight_name_mapping: dict = None,
):
tp_rank = get_tensor_model_parallel_rank()
if is_nextn:
logging.warning(
"Loading weights for nextn is currently not supported in GptOssForCausalLM. "
)
return
weights = _canonicalize_weights(self.config, weights)
weights = sorted(weights, key=lambda x: x[0]) # Sort by name for consistency
new_weights = []
for name, p in weights:
if "qkv.weight" in name:
q_proj, k_proj, v_proj = p.split(
[
self.config.num_attention_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
],
dim=0,
)
new_weights.append(
(f"{name.replace('qkv.weight', 'q_proj.weight')}", q_proj)
)
new_weights.append(
(f"{name.replace('qkv.weight', 'k_proj.weight')}", k_proj)
)
new_weights.append(
(f"{name.replace('qkv.weight', 'v_proj.weight')}", v_proj)
)
elif "qkv.bias" in name:
q_bias, k_bias, v_bias = p.split(
[
self.config.num_attention_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
self.config.num_key_value_heads * self.config.head_dim,
],
dim=0,
)
new_weights.append(
(f"{name.replace('qkv.bias', 'q_proj.bias')}", q_bias)
)
new_weights.append(
(f"{name.replace('qkv.bias', 'k_proj.bias')}", k_bias)
)
new_weights.append(
(f"{name.replace('qkv.bias', 'v_proj.bias')}", v_bias)
)
else:
new_weights.append((name, p))
weights = new_weights
# Use provided weight name mapping if available, otherwise use default
if weight_name_mapping is None:
weight_name_mapping = self._get_default_weight_mapping()
else:
# Merge with default mapping
default_mapping = self._get_default_weight_mapping()
default_mapping.update(weight_name_mapping)
weight_name_mapping = default_mapping
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping_fused(
ckpt_gate_up_proj_name="gate_up_proj",
ckpt_down_proj_name="down_proj",
ckpt_gate_up_proj_bias_name="gate_up_proj_bias",
ckpt_down_proj_bias_name="down_proj_bias",
)
params_dict = dict(self.named_parameters())
params_checker = {k: False for k, v in params_dict.items()}
for name, loaded_weight in weights:
loaded_weight = _WeightCreator.maybe_materialize(loaded_weight)
# Apply weight name mapping if provided
if weight_name_mapping and name in weight_name_mapping:
name = weight_name_mapping[name]
layer_id = get_layer_id(name)
if (
layer_id is not None
and hasattr(self.model, "start_layer")
and (
layer_id < self.model.start_layer
or layer_id >= self.model.end_layer
)
):
continue
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
params_checker[name] = True
break
else:
for mapping in expert_params_mapping:
param_name, weight_name, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
if "bias" not in name:
loaded_weight = loaded_weight.transpose(-2, -1)
if "w2_weight_bias" in name and get_moe_tensor_parallel_rank() != 0:
loaded_weight = loaded_weight.zero_()
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
)
params_checker[name] = True
break
else:
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
if name in params_dict.keys():
param = params_dict[name]
if "sinks" in name:
start = tp_rank * param.numel()
param.data.copy_(
loaded_weight[start : start + param.numel()]
)
else:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
weight_loader(param, loaded_weight)
params_checker[name] = True
else:
logger.warning(f"Parameter {name} not found in params_dict")
not_loaded_params = [k for k, v in params_checker.items() if not v]
if tp_rank == 0:
if len(not_loaded_params) > 0:
raise Exception(f"Not all parameters loaded: {not_loaded_params}")
else:
logging.info("All parameters loaded successfully.")
self.routed_experts_weights_of_layer = {
layer_id: self.model.layers[layer_id].mlp.get_moe_weights()
for layer_id in range(self.start_layer, self.end_layer)
if isinstance(self.model.layers[layer_id].mlp, GptOssSparseMoeBlock)
}
def get_embed_and_head(self):
return self.model.embed_tokens.weight, self.lm_head.weight
def set_embed_and_head(self, embed, head):
del self.model.embed_tokens.weight
del self.lm_head.weight
self.model.embed_tokens.weight = embed
self.lm_head.weight = head
torch.cuda.empty_cache()
torch.cuda.synchronize()
def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
if not self.pp_group.is_last_rank:
return
if layer_ids is None:
self.capture_aux_hidden_states = True
num_layers = self.config.num_hidden_layers
self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
else:
self.capture_aux_hidden_states = True
# We add 1 here because in SGLang the ith layer captures the output of the
# (i-1)th layer as its aux hidden state.
self.model.layers_to_capture = [val + 1 for val in layer_ids]
@classmethod
def get_model_config_for_expert_location(cls, config):
return ModelConfigForExpertLocation(
num_layers=config.num_hidden_layers,
num_logical_experts=config.num_local_experts,
num_groups=None,
)
def get_attention_sliding_window_size(self):
return get_attention_sliding_window_size(self.config)
def _canonicalize_weights(config, weights_in: Iterable[Tuple[str, torch.Tensor]]):
weights_out_dict = dict(weights_in)
for layer_id in range(config.num_hidden_layers):
for name_chunk in ["mlp1_weight", "mlp2_weight"]:
name_prefix = f"block.{layer_id}.mlp.{name_chunk}"
w_blocks = weights_out_dict.pop(f"{name_prefix}.blocks", None)
w_scales = weights_out_dict.pop(f"{name_prefix}.scales", None)
if w_blocks is not None:
weights_out_dict[name_prefix] = _WeightCreator(
partial(
_dequant_mlp_weight,
debug_name=name_prefix,
w_blocks=w_blocks,
w_scales=w_scales,
)
)
return list(weights_out_dict.items())
def _dequant_mlp_weight(debug_name, w_blocks, w_scales):
if get_tensor_model_parallel_rank() == 0:
logger.info(f"Dequantize {debug_name} start")
original_device = w_blocks.device
w_blocks = w_blocks.cuda()
w_scales = w_scales.cuda()
w_bf16 = dequant_mxfp4(w_block=w_blocks, w_scale=w_scales, out_dtype=torch.bfloat16)
w_bf16 = w_bf16.transpose(-2, -1).contiguous()
if get_tensor_model_parallel_rank() == 0:
logger.info(
f"Dequantize {debug_name} end {w_blocks.shape=} {w_scales.shape=} {w_bf16.shape=}"
)
return w_bf16.to(original_device)
class _WeightCreator:
def __init__(self, fn):
self._fn = fn
@staticmethod
def maybe_materialize(obj):
if isinstance(obj, _WeightCreator):
output = obj._fn()
obj._fn = None
return output
return obj
EntryClass = GptOssForCausalLM
......@@ -457,6 +457,10 @@ class ServerArgs:
raise ValueError(
"trtllm_mla backend does not support speculative decoding yet."
)
model_arch = self.get_hf_config().architectures[0]
if model_arch in ["GptOssForCausalLM"]:
self.attention_backend = "triton"
self.enable_triton_kernel_moe = True
# Set page size
if self.page_size is None:
......