Unverified commit b819381f authored by HAI, committed by GitHub

AITER backend extension and workload optimizations (#6838)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
parent 562f279a
......@@ -72,7 +72,7 @@ jobs:
- name: Evaluate accuracy (TP=2)
timeout-minutes: 30
run: |
bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py
bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py
mla-test-1-gpu-amd:
if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
......
......@@ -53,7 +53,7 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value |
| --- | --- | --- |
| `SGLANG_AITER_MOE` | Use AITER MOE implementation | `false` |
| `SGLANG_USE_AITER` | Use AITER optimized implementations | `false` |
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
......
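`SGLANG_AITER_MOE` is folded into the broader `SGLANG_USE_AITER` switch in this change. For context, here is a minimal sketch of the gating pattern the modules below all follow; the helper bodies are illustrative stand-ins for `sglang.srt.utils.get_bool_env_var` / `is_hip`, not the real implementations:

```python
import os

import torch


def get_bool_env_var(name: str, default: str = "false") -> bool:
    # Illustrative stand-in for sglang.srt.utils.get_bool_env_var:
    # treat "1"/"true"/"yes" (case-insensitive) as enabled.
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes")


def is_hip() -> bool:
    # Illustrative stand-in for sglang.srt.utils.is_hip: True on ROCm/HIP builds.
    return torch.version.hip is not None


# The pattern repeated in the diffs below: AITER code paths are taken only
# when the environment flag is set AND the build targets ROCm/HIP.
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and is_hip()
```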
......@@ -27,12 +27,19 @@ if TYPE_CHECKING:
from sglang.srt.speculative.spec_info import SpecInfo
try:
from aiter import mha_batch_prefill_func, paged_attention_ragged
from aiter import (
flash_attn_varlen_func,
mha_batch_prefill_func,
paged_attention_ragged,
)
from aiter.mla import mla_decode_fwd
except ImportError:
print(
"aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
)
from sglang.srt.configs.model_config import AttentionArch
class WrapperDispatch(Enum):
SLIDING_WINDOW = auto()
......@@ -43,6 +50,10 @@ class WrapperDispatch(Enum):
class ForwardMetadata:
kv_indptr: torch.Tensor
kv_indices: torch.Tensor
qo_indptr: torch.Tensor
kv_last_page_len: torch.Tensor
max_extend_len: int
max_prefix_extend_len: int
max_q_len: int
max_kv_len: int
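`ForwardMetadata` grows from four to eight fields, and every call site below constructs it positionally, passing `None` for the fields its mode does not use. A sketch of the expanded layout, with field names and order taken from the hunk above; the `Optional` annotations and the MLA/non-MLA grouping are inferred from the constructions further down:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class ForwardMetadata:
    # Shared by all paths
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
    # Populated on the MLA paths, None otherwise
    qo_indptr: Optional[torch.Tensor]
    kv_last_page_len: Optional[torch.Tensor]
    max_extend_len: Optional[int]
    max_prefix_extend_len: Optional[int]
    # Populated on the non-MLA prefill paths, None otherwise
    max_q_len: Optional[int]
    max_kv_len: Optional[int]
```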
......@@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend):
self.device = model_runner.device
self.is_multimodal = model_runner.model_config.is_multimodal
self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
self.num_head = (
model_runner.model_config.num_attention_heads // get_attention_tp_size()
)
......@@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend):
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
# Parse constants
self.max_context_len = model_runner.model_config.context_len
self.skip_prefill = skip_prefill
......@@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend):
self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
model_runner, self
)
if self.use_mla:
self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
model_runner, self
)
# aiter kernel related initialization
self.max_num_partitions = (
......@@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend):
nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
* nbyes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
if not self.use_mla:
self.workspace_buffer = torch.empty(
(max_bs * self.num_head * self.max_num_partitions * self.head_dim)
* nbyes_per_qo_elem
+ 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
dtype=torch.uint8,
device=self.device,
)
self.scale = float(1.0 / (self.head_dim**0.5))
self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
self.device
)
self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
self.device
)
self.logits_soft_cap = 0.0
self.forward_metadata: ForwardMetadata = None
if self.use_mla:
self.qo_indptr_ = torch.zeros(
(max_bs + 1,), dtype=torch.int32, device=model_runner.device
)
def init_forward_metadata(self, forward_batch: ForwardBatch):
"""Init auxiliary variables for triton attention backend."""
bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
qo_indptr = None
kv_last_page_len = None
max_extend_len = None
if forward_batch.forward_mode.is_decode_or_idle():
# update for aiter
# create kv_indices and kv_indptr
bs = forward_batch.batch_size
kv_indptr = self.kv_indptr
spec_info = forward_batch.spec_info
if spec_info is None:
kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
......@@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend):
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
bs = kv_indptr.shape[0] - 1
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
if self.use_mla:
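# Decode: each request contributes exactly one new query token, so the
# cumulative sum of the (all-ones) last-page lengths yields
# qo_indptr = [0, 1, ..., bs] and the extend length is fixed at 1.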
qo_indptr = self.qo_indptr_[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
kv_last_page_len = self.kv_last_page_len[:bs]
max_extend_len = 1
elif forward_batch.forward_mode.is_draft_extend():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
None,
)
elif forward_batch.forward_mode.is_draft_extend():
if self.use_mla:
prefix_lens = forward_batch.extend_prefix_lens
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
)
else:
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
elif forward_batch.forward_mode.is_target_verify():
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
if self.use_mla:
prefix_lens = forward_batch.extend_prefix_lens
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
)
else:
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens=None,
encoder_lens=forward_batch.encoder_lens,
spec_info=forward_batch.spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
else:
prefix_lens = forward_batch.extend_prefix_lens
......@@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend):
else:
extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
if self.use_mla:
self.mla_indices_updater_prefill.update(
forward_batch.req_pool_indices,
prefix_lens,
prefix_lens.sum().item(),
forward_batch.extend_seq_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.mla_indices_updater_prefill.kv_indptr,
self.mla_indices_updater_prefill.kv_indices,
self.mla_indices_updater_prefill.qo_indptr,
self.mla_indices_updater_prefill.kv_last_page_len,
self.mla_indices_updater_prefill.max_extend_len,
self.mla_indices_updater_prefill.max_prefix_extend_len,
None,
None,
)
else:
self.indices_updater_prefill.update(
forward_batch.req_pool_indices,
forward_batch.seq_lens,
forward_batch.seq_lens_sum,
prefix_lens,
encoder_lens=forward_batch.encoder_lens,
spec_info=None,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
def init_cuda_graph_state(
self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
):
self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
if kv_indices_buf is None:
self.cuda_graph_kv_indices = torch.zeros(
(max_bs * self.max_context_len),
......@@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend):
spec_info: Optional[SpecInfo],
):
if forward_mode.is_decode_or_idle():
qo_indptr = None
kv_last_page_len = None
max_extend_len = None
if spec_info is None:
kv_indptr = self.kv_indptr
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
......@@ -255,25 +374,83 @@ class AiterAttnBackend(AttentionBackend):
)
else:
kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
elif forward_mode.is_target_verify():
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
if self.use_mla:
qo_indptr = self.qo_indptr_[: bs + 1]
qo_indptr[1 : bs + 1] = torch.cumsum(
self.cuda_graph_kv_last_page_len[:bs], dim=0
)
max_extend_len = 1
kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
None,
)
elif forward_mode.is_target_verify():
if self.use_mla:
qo_indptr = self.qo_indptr[: bs + 1]
qo_indptr[: bs + 1] = torch.arange(
0,
(1 + bs) * self.num_draft_tokens,
step=self.num_draft_tokens,
dtype=torch.int32,
device=self.device,
)
kv_indptr = self.kv_indptr[: bs + 1]
kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
kv_indices = self.cuda_graph_kv_indices
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
seq_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
max_extend_len = self.num_draft_tokens
kv_last_page_len = None
self.forward_metadata = ForwardMetadata(
kv_indptr,
kv_indices,
qo_indptr,
kv_last_page_len,
max_extend_len,
None,
None,
None,
)
else:
seq_lens_sum = seq_lens.sum().item()
self.indices_updater_prefill.update(
req_pool_indices,
seq_lens,
seq_lens_sum,
prefix_lens=None,
encoder_lens=encoder_lens,
spec_info=spec_info,
)
self.forward_metadata = ForwardMetadata(
self.indices_updater_prefill.kv_indptr,
self.indices_updater_prefill.kv_indices,
None,
None,
None,
None,
self.indices_updater_prefill.max_q_len,
self.indices_updater_prefill.max_kv_len,
)
else:
raise ValueError(f"Invalid mode: {forward_mode=}")
......@@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend):
if k is not None:
assert v is not None
if save_kv_cache:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
bs0 = forward_batch.batch_size + 1
o = mha_batch_prefill_func(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache,
v_cache,
self.qo_indptr[:bs0],
self.forward_metadata.kv_indptr[:bs0],
self.forward_metadata.kv_indices,
self.forward_metadata.max_q_len,
self.forward_metadata.max_kv_len,
causal=True,
logits_soft_cap=self.logits_soft_cap,
alibi_slopes=None,
return_lse=False,
return_attn_probs=False,
)
if self.use_mla:
forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
else:
forward_batch.token_to_kv_pool.set_kv_buffer(
layer, cache_loc, k, v, layer.k_scale, layer.v_scale
)
if self.use_mla:
max_extend_len = self.forward_metadata.max_extend_len
max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
kv_indptr = self.forward_metadata.kv_indptr
kv_indices = self.forward_metadata.kv_indices
kv_last_page_lens = self.forward_metadata.kv_last_page_len
qo_indptr = self.forward_metadata.qo_indptr
K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
kv_lora_rank = V_Buffer.shape[-1]
qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
assert len(q.shape) == 3
assert len(k.shape) == 3
assert len(v.shape) == 3
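# No cached prefix KV for this batch: attend only over the extend tokens
# themselves (qo_indptr is used for both the Q and KV cu_seqlens).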
if kv_indices.shape[0] == 0:
o = flash_attn_varlen_func(
q,
k,
v,
qo_indptr,
qo_indptr,
max_extend_len,
max_extend_len,
softmax_scale=layer.scaling,
causal=True,
)
return o
elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
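# A cached prefix exists and q is in MHA layout: gather the latent KV
# cache, up-project it with kv_b_proj to recover prefix K/V, then
# interleave prefix and extend segments per request before running
# varlen attention over the concatenated keys/values.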
K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
kvc, k_pe = torch.split(
K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
)
kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
kvprefix = kvprefix.view(
-1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
)
k_prefix, v_prefix = torch.split(
kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
)
k_prefix = torch.cat(
[
k_prefix,
torch.broadcast_to(
k_pe,
(k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
),
],
dim=-1,
)
assert (
forward_batch.extend_prefix_lens.shape
== forward_batch.extend_seq_lens.shape
)
k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
o = flash_attn_varlen_func(
q,
k,
v,
qo_indptr,
kv_indptr,
max_extend_len,
max_prefix_extend_len,
softmax_scale=layer.scaling,
causal=True,
)
return o
else:
k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
layer.layer_id
)
bs0 = forward_batch.batch_size + 1
o = mha_batch_prefill_func(
q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
k_cache,
v_cache,
self.qo_indptr[:bs0],
self.forward_metadata.kv_indptr[:bs0],
self.forward_metadata.kv_indices,
self.forward_metadata.max_q_len,
self.forward_metadata.max_kv_len,
causal=True,
logits_soft_cap=self.logits_soft_cap,
alibi_slopes=None,
return_lse=False,
return_attn_probs=False,
)
return o.view(-1, layer.tp_q_head_num * layer.head_dim)
def forward_decode(
self,
......@@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend):
forward_batch: ForwardBatch,
save_kv_cache=True,
):
q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
if layer.qk_head_dim != layer.v_head_dim:
......@@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend):
layer, forward_batch.out_cache_loc, k, v
)
self.logits_soft_cap = layer.logit_cap
paged_attention_ragged(
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
self.workspace_buffer,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-1, 1, layer.tp_v_head_num, layer.v_head_dim
),
self.scale,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.kv_last_page_lens,
1,
self.max_num_partitions,
None,
"auto",
"NHD",
self.logits_soft_cap,
self.k_scale,
self.v_scale,
None,
_AITER_PARTITION_SIZE_ROCM,
)
if self.use_mla:
k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
mla_decode_fwd(
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
k_buffer.view(-1, 1, 1, layer.qk_head_dim),
o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
self.forward_metadata.qo_indptr,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.forward_metadata.kv_last_page_len,
self.forward_metadata.max_extend_len,
layer.scaling,
layer.logit_cap,
)
k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
else:
self.logits_soft_cap = layer.logit_cap
paged_attention_ragged(
o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
self.workspace_buffer,
q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-1, 1, layer.tp_k_head_num, layer.qk_head_dim
),
forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-1, 1, layer.tp_v_head_num, layer.v_head_dim
),
self.scale,
self.forward_metadata.kv_indptr,
self.forward_metadata.kv_indices,
self.kv_last_page_len,
1,
self.max_num_partitions,
None,
"auto",
"NHD",
self.logits_soft_cap,
self.k_scale,
self.v_scale,
None,
_AITER_PARTITION_SIZE_ROCM,
)
return o
......@@ -506,9 +782,97 @@ class AiterIndicesUpdaterPrefill:
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
None,
paged_kernel_lens_sum,
self.req_to_token,
)
)
self.kv_indices = kv_indices
class AiterMlaIndicesUpdaterPrefill:
def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
# Parse Constants
self.attn_backend = attn_backend
# Buffers and wrappers
self.req_to_token = model_runner.req_to_token_pool.req_to_token
self.update = self.update_single_wrapper
self.kv_indptr = None
self.kv_indices = None
self.qo_indptr = None
self.kv_last_page_len = None
self.max_extend_len = 0
self.max_prefix_extend_len = 0
def update(
self,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
prefix_lens_sum: int,
extend_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
# Keep the signature for type checking. It will be assigned during runtime.
raise NotImplementedError()
def update_single_wrapper(
self,
req_pool_indices: torch.Tensor,
prefix_lens: torch.Tensor,
prefix_lens_sum: int,
extend_lens: torch.Tensor,
encoder_lens: Optional[torch.Tensor],
spec_info: Optional[SpecInfo],
):
paged_kernel_lens = prefix_lens
paged_kernel_lens_sum = prefix_lens_sum
bs = len(req_pool_indices)
kv_indptr = self.attn_backend.kv_indptr
if spec_info is None:
# Normal extend
kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
kv_indptr = kv_indptr[: bs + 1]
kv_indices = torch.empty(
paged_kernel_lens_sum,
dtype=torch.int32,
device=req_pool_indices.device,
)
create_flashinfer_kv_indices_triton[(bs,)](
self.req_to_token,
req_pool_indices,
paged_kernel_lens,
kv_indptr,
None,
kv_indices,
self.req_to_token.stride(0),
)
qo_indptr = self.attn_backend.qo_indptr
qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
qo_indptr = qo_indptr[: bs + 1]
max_extend_len = torch.max(extend_lens).item()
max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
kv_indptr += qo_indptr
else:
kv_indices, kv_indptr, qo_indptr, custom_mask = (
spec_info.generate_attn_arg_prefill(
req_pool_indices,
paged_kernel_lens,
paged_kernel_lens_sum,
self.req_to_token,
)
)
self.kv_indptr = kv_indptr
self.kv_indices = kv_indices
self.qo_indptr = qo_indptr
self.max_extend_len = max_extend_len
self.max_prefix_extend_len = max_prefix_extend_len
......@@ -20,10 +20,11 @@ import torch
import torch.nn as nn
from sglang.srt.custom_op import CustomOp
from sglang.srt.utils import is_cuda, is_hip
from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip
_is_cuda = is_cuda()
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
from sgl_kernel import (
......@@ -33,7 +34,10 @@ if _is_cuda:
rmsnorm,
)
if _is_hip:
if _use_aiter:
from aiter import rmsnorm2d_fwd as rms_norm
from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
elif _is_hip:
from vllm._custom_ops import fused_add_rms_norm, rms_norm
logger = logging.getLogger(__name__)
......@@ -48,6 +52,8 @@ class RMSNorm(CustomOp):
super().__init__()
self.weight = nn.Parameter(torch.ones(hidden_size))
self.variance_epsilon = eps
if _use_aiter:
self._forward_method = self.forward_aiter
def forward_cuda(
self,
......@@ -60,6 +66,25 @@ class RMSNorm(CustomOp):
out = rmsnorm(x, self.weight.data, self.variance_epsilon)
return out
def forward_aiter(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if residual is not None:
residual_out = torch.empty_like(x)
output = torch.empty_like(x)
fused_add_rms_norm(
output,
x,
residual,
residual_out,
self.weight.data,
self.variance_epsilon,
)
return output, residual_out
return rms_norm(x, self.weight.data, self.variance_epsilon)
def forward_hip(
self,
x: torch.Tensor,
......
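For readers unfamiliar with the AITER entry points used above: `rms_norm` and `fused_add_rms_norm` implement standard RMSNorm semantics. Below is a plain PyTorch reference sketch of those semantics, not the AITER kernels themselves (which fuse the steps and write into preallocated output buffers, as `forward_aiter` shows):

```python
import torch


def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, with the reduction done in fp32.
    x32 = x.float()
    variance = x32.pow(2).mean(dim=-1, keepdim=True)
    return (x32 * torch.rsqrt(variance + eps)).to(x.dtype) * weight


def fused_add_rms_norm_ref(x, residual, weight, eps):
    # The fused variant adds the residual first, then normalizes. It returns
    # both the normalized output and the updated residual, matching the
    # (output, residual_out) pair returned by forward_aiter above.
    residual_out = x + residual
    return rms_norm_ref(residual_out, weight, eps), residual_out
```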
......@@ -1332,7 +1332,7 @@ def fused_experts_impl(
if (
not (use_fp8_w8a8 or use_int8_w8a8)
or block_shape is not None
or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
):
padded_size = 0
......
......@@ -28,8 +28,9 @@ else:
import logging
_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_hip:
if _use_aiter:
from aiter import ActivationType
from aiter.fused_moe_bf16_asm import ck_moe_2stages
from aiter.ops.shuffle import shuffle_weight
......@@ -104,7 +105,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
set_weight_attrs(w2_weight, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
......@@ -188,7 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
routed_scaling_factor=routed_scaling_factor,
)
if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
if _use_aiter:
assert not no_combine, "unsupported"
if apply_router_weight_on_input:
assert (
......
......@@ -77,8 +77,8 @@ _is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_hip:
from aiter import ActivationType, QuantType
......@@ -487,7 +487,7 @@ class Fp8MoEMethod:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
if self.quant_config.is_checkpoint_fp8_serialized:
params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
tp_size = get_tensor_model_parallel_world_size()
if self.block_quant:
block_n, block_k = (
......@@ -512,7 +512,7 @@ class Fp8MoEMethod:
)
# WEIGHTS
if _is_hip and use_hip_int4:
if _is_hip and _use_hip_int4:
# INT4 MoE weight - INT32 packed
w13_weight = torch.nn.Parameter(
torch.empty(
......@@ -641,7 +641,7 @@ class Fp8MoEMethod:
layer.register_parameter("w13_weight_scale", w13_weight_scale)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
if _is_hip: # and use_aiter_moe: TODO: add check back after triton kernel
if _is_hip:  # TODO: restore the _use_aiter check once the Triton kernel is added
# ROCm: use column scaling; duplicate the scaling numbers for the per-tensor scaling case
w13_weight_scale1 = torch.nn.Parameter(
torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
......@@ -668,7 +668,7 @@ class Fp8MoEMethod:
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
if _is_hip and use_hip_int4:
if _is_hip and _use_hip_int4:
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
......@@ -700,7 +700,7 @@ class Fp8MoEMethod:
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: Module) -> None:
if _is_hip and use_hip_int4:
if _is_hip and _use_hip_int4:
self.process_weights_hip_int4(layer)
return
......@@ -731,7 +731,7 @@ class Fp8MoEMethod:
)
layer.w2_input_scale = None
if _is_hip and use_aiter_moe:
if _use_aiter:
# Pre-shuffle weights
layer.w13_weight.data = shuffle_weight(
layer.w13_weight.contiguous(), (16, 16)
......@@ -853,7 +853,7 @@ class Fp8MoEMethod:
return
def process_weights_hip_int4(self, layer: Module):
# TODO: and use_aiter_moe: add after triton kernel added
# TODO: add the _use_aiter check after the Triton kernel is added
# INT4-FP8 (INT4 MoE Weight, FP8 Compute)
# Weight Permutation
layer.w13_weight = torch.nn.Parameter(
......@@ -900,7 +900,7 @@ class Fp8MoEMethod:
padding_size, # Avoid circular import
)
if use_aiter_moe:
if _use_aiter:
layer.w13_weight = torch.nn.Parameter(
shuffle_weight(layer.w13_weight.data, (16, 16)),
requires_grad=False,
......@@ -911,7 +911,7 @@ class Fp8MoEMethod:
requires_grad=False,
)
torch.cuda.empty_cache()
# ROCm (use_aiter_moe): using column-wise scaling
# ROCm (_use_aiter): using column-wise scaling
layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
elif get_bool_env_var("SGLANG_MOE_PADDING"):
......@@ -1041,8 +1041,8 @@ class Fp8MoEMethod:
activation: str = "silu",
no_combine: bool = False,
) -> Optional[torch.Tensor]:
if use_hip_int4:
# TODO: add triton kernel and add check use_aiter_moe
if _use_hip_int4:
# TODO: add triton kernel and add check _use_aiter
assert not no_combine, f"{no_combine=} is not supported."
return ck_moe_2stages(
x,
......@@ -1058,13 +1058,13 @@ class Fp8MoEMethod:
),
)
if use_aiter_moe:
if _use_aiter:
assert not no_combine, f"{no_combine=} is not supported."
if self.block_quant:
# TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
# TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time being.
assert (
activation == "silu"
), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
return asm_moe(
x,
layer.w13_weight,
......
......@@ -38,11 +38,10 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
if _is_hip and use_aiter_moe:
from aiter import gemm_a8w8_blockscale
if _use_aiter:
from aiter import gemm_a8w8_blockscale_CK
if _is_cuda:
from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
......@@ -141,7 +140,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
return flashinfer_gemm_w8a8_block_fp8_linear
elif CUTLASS_BLOCK_FP8_SUPPORTED:
return cutlass_w8a8_block_fp8_linear_with_fallback
elif _is_hip and use_aiter_moe:
elif _use_aiter:
return aiter_w8a8_block_fp8_linear
elif _ENABLE_JIT_DEEPGEMM:
return deepgemm_w8a8_block_fp8_linear_with_fallback
......@@ -268,12 +267,9 @@ def aiter_w8a8_block_fp8_linear(
q_input, x_scale = per_token_group_quant_fp8(
input_2d, block_size[1], column_major_scales=False
)
output = torch.zeros(
[q_input.shape[0], weight.shape[0]],
dtype=input_2d.dtype,
device=q_input.device,
output = gemm_a8w8_blockscale_CK(
q_input, weight, x_scale, weight_scale, dtype=input.dtype
)
gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)
if bias is not None:
output += bias
......
......@@ -355,6 +355,15 @@ class ModelRunner:
# MLA architecture
if is_hopper_with_cuda_12_3():
server_args.attention_backend = "fa3"
elif _is_hip:
head_num = self.model_config.get_num_kv_heads(self.tp_size)
# TODO: aiter currently only supports head numbers of 16 or 128
if (
head_num == 128 or head_num == 16
) and self.spec_algorithm.is_none():
server_args.attention_backend = "aiter"
else:
server_args.attention_backend = "triton"
else:
server_args.attention_backend = "triton"
logger.info(
......@@ -363,6 +372,7 @@ class ModelRunner:
elif self.use_mla_backend:
if server_args.device != "cpu":
if server_args.attention_backend in [
"aiter",
"flashinfer",
"fa3",
"triton",
......
......@@ -105,6 +105,7 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if _is_cuda:
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
......@@ -120,6 +121,9 @@ if _is_hip:
decode_attention_fwd_grouped_rope,
)
if _use_aiter:
from aiter.rotary_embedding import get_rope
logger = logging.getLogger(__name__)
......@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
)
self.alt_stream = alt_stream
self.attn_mha.kv_b_proj = None
self.w_kc = None
self.w_vc = None
......@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
return AttnForwardMethod.MHA_CHUNKED_KV
else:
return _dispatch_mla_subtype()
elif self.attention_backend == "aiter":
if (
forward_batch.forward_mode.is_extend()
and not forward_batch.forward_mode.is_target_verify()
and not forward_batch.forward_mode.is_draft_extend()
):
return AttnForwardMethod.MHA
else:
return AttnForwardMethod.MLA
else:
# Triton: Use normal computation for prefill and use weight absorption for extend/decode
if (
......@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
forward_batch: ForwardBatch,
zero_allocator: BumpAllocator,
):
if self.attn_mha.kv_b_proj is None:
self.attn_mha.kv_b_proj = self.kv_b_proj
if hidden_states.shape[0] == 0:
assert (
not self.o_proj.reduce_results
......
#!/bin/bash
set -euo pipefail
# Default working directory
WORKDIR="/sglang-checkout/test/srt"
ENV_ARGS=(
-e SGLANG_AMD_CI=1
-e SGLANG_IS_IN_CI=1
-e SGLANG_AITER_MOE=1
declare -A ENV_MAP=(
[SGLANG_AMD_CI]=1
[SGLANG_IS_IN_CI]=1
[SGLANG_USE_AITER]=1
)
# Parse optional -w/--workdir and -e ENV=VAL flags
# Parse -w/--workdir and -e ENV=VAL
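# Example (mirrors the CI usage above):
#   bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py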
while [[ $# -gt 0 ]]; do
case "$1" in
-w|--workdir)
......@@ -17,7 +16,8 @@ while [[ $# -gt 0 ]]; do
shift 2
;;
-e)
ENV_ARGS+=("-e" "$2")
IFS="=" read -r key val <<< "$2"
ENV_MAP["$key"]="$val"
shift 2
;;
--)
......@@ -30,6 +30,12 @@ while [[ $# -gt 0 ]]; do
esac
done
# Build final ENV_ARGS
ENV_ARGS=()
for key in "${!ENV_MAP[@]}"; do
ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
done
# Run docker exec
docker exec \
-w "$WORKDIR" \
......
......@@ -171,7 +171,7 @@ class TestNightlyGsm8KEval(unittest.TestCase):
os.environ["HF_HUB_DISABLE_XET"] = (
"1" if model in DISABLE_HF_XET_MODELS else "0"
)
os.environ["SGLANG_AITER_MOE"] = (
os.environ["SGLANG_USE_AITER"] = (
"0" if model in TRITON_MOE_MODELS else "1"
)
......