Unverified commit b819381f authored by HAI, committed by GitHub

AITER backend extension and workload optimizations (#6838)


Co-authored-by: wunhuang <wunhuang@amd.com>
Co-authored-by: Hubert Lu <Hubert.Lu@amd.com>
parent 562f279a
@@ -72,7 +72,7 @@ jobs:
      - name: Evaluate accuracy (TP=2)
        timeout-minutes: 30
        run: |
-          bash scripts/amd_ci_exec.sh python3 test_moe_eval_accuracy_large.py
+          bash scripts/amd_ci_exec.sh -e SGLANG_USE_AITER=0 python3 test_moe_eval_accuracy_large.py

  mla-test-1-gpu-amd:
    if: (github.repository == 'sgl-project/sglang' || github.event_name == 'pull_request') &&
...
@@ -53,7 +53,7 @@ SGLang supports various environment variables that can be used to configure its
| Environment Variable | Description | Default Value |
| --- | --- | --- |
-| `SGLANG_AITER_MOE` | Use AITER MOE implementation | `false` |
+| `SGLANG_USE_AITER` | Use the AITER optimized implementations | `false` |
| `SGLANG_INT4_WEIGHT` | Enable INT4 weight quantization | `false` |
| `SGLANG_MOE_PADDING` | Enable MoE padding (sets padding size to 128 if value is `1`, often set to `1` in Docker builds) | `0` |
| `SGLANG_FORCE_FP8_MARLIN` | Force using FP8 MARLIN kernels even if other FP8 kernels are available | `false` |
...
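As a quick reference for the rename above, a minimal sketch (not part of this commit) of toggling the flag from Python before sglang is imported, mirroring the `os.environ["SGLANG_USE_AITER"]` pattern used in the nightly test changed at the end of this diff; the engine entry point and model path are illustrative assumptions.

```python
import os

# The flag is read through get_bool_env_var() at import time in several modules,
# so it must be set before any sglang imports.
os.environ["SGLANG_USE_AITER"] = "1"   # opt into AITER kernels on ROCm
os.environ["SGLANG_MOE_PADDING"] = "0"  # other knobs keep their documented defaults

import sglang as sgl  # assumed offline-engine entry point

if __name__ == "__main__":
    # Illustrative model; any model supported on the ROCm build works the same way.
    engine = sgl.Engine(model_path="deepseek-ai/DeepSeek-V2-Lite")
    print(engine.generate("The capital of France is", {"max_new_tokens": 8}))
```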
@@ -27,12 +27,19 @@ if TYPE_CHECKING:
    from sglang.srt.speculative.spec_info import SpecInfo

try:
-    from aiter import mha_batch_prefill_func, paged_attention_ragged
+    from aiter import (
+        flash_attn_varlen_func,
+        mha_batch_prefill_func,
+        paged_attention_ragged,
+    )
+    from aiter.mla import mla_decode_fwd
except ImportError:
    print(
        "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
    )

+from sglang.srt.configs.model_config import AttentionArch
+

class WrapperDispatch(Enum):
    SLIDING_WINDOW = auto()
@@ -43,6 +50,10 @@ class WrapperDispatch(Enum):
class ForwardMetadata:
    kv_indptr: torch.Tensor
    kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    kv_last_page_len: torch.Tensor
+    max_extend_len: int
+    max_prefix_extend_len: int
    max_q_len: int
    max_kv_len: int
@@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend):
        self.device = model_runner.device
        self.is_multimodal = model_runner.model_config.is_multimodal
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
        self.num_head = (
            model_runner.model_config.num_attention_heads // get_attention_tp_size()
        )
@@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend):
        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+
        # Parse constants
        self.max_context_len = model_runner.model_config.context_len
        self.skip_prefill = skip_prefill
@@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend):
            self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
                model_runner, self
            )
+            if self.use_mla:
+                self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
+                    model_runner, self
+                )

        # aiter kernel related initialization
        self.max_num_partitions = (
@@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend):
        nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
-        self.workspace_buffer = torch.empty(
-            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
-            * nbyes_per_qo_elem
-            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
-            dtype=torch.uint8,
-            device=self.device,
-        )
+        if not self.use_mla:
+            self.workspace_buffer = torch.empty(
+                (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
+                * nbyes_per_qo_elem
+                + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
+                dtype=torch.uint8,
+                device=self.device,
+            )

        self.scale = float(1.0 / (self.head_dim**0.5))
        self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
            self.device
        )
        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
            self.device
        )

        self.logits_soft_cap = 0.0

        self.forward_metadata: ForwardMetadata = None

+        if self.use_mla:
+            self.qo_indptr_ = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )

    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for triton attention backend."""
+
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info
+        qo_indptr = None
+        kv_last_page_len = None
+        max_extend_len = None
+
        if forward_batch.forward_mode.is_decode_or_idle():
-            # update for aiter
-            # create kv_indices and kv_inptr
-            bs = forward_batch.batch_size
-            kv_indptr = self.kv_indptr
-            spec_info = forward_batch.spec_info
            if spec_info is None:
                kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                kv_indptr = kv_indptr[: bs + 1]
@@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend):
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                bs = kv_indptr.shape[0] - 1

-            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+            if self.use_mla:
+                qo_indptr = self.qo_indptr_[: bs + 1]
+                qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
+                kv_last_page_len = self.kv_last_page_len[:bs]
+                max_extend_len = 1
+
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_extend_len,
+                None,
+                None,
+                None,
+            )
        elif forward_batch.forward_mode.is_draft_extend():
-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=forward_batch.spec_info,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                prefix_lens = forward_batch.extend_prefix_lens
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=forward_batch.spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
        elif forward_batch.forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=forward_batch.spec_info,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                prefix_lens = forward_batch.extend_prefix_lens
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=forward_batch.spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
        else:
            prefix_lens = forward_batch.extend_prefix_lens
@@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend):
            else:
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=None,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )

    def init_cuda_graph_state(
        self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
    ):
+        self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
        if kv_indices_buf is None:
            self.cuda_graph_kv_indices = torch.zeros(
                (max_bs * self.max_context_len),
@@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend):
        spec_info: Optional[SpecInfo],
    ):
        if forward_mode.is_decode_or_idle():
+            qo_indptr = None
+            kv_last_page_len = None
+            max_extend_len = None
+
            if spec_info is None:
                kv_indptr = self.kv_indptr
                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
@@ -255,25 +374,83 @@ class AiterAttnBackend(AttentionBackend):
                )
            else:
                kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
-            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+
+            if self.use_mla:
+                qo_indptr = self.qo_indptr_[: bs + 1]
+                qo_indptr[1 : bs + 1] = torch.cumsum(
+                    self.cuda_graph_kv_last_page_len[:bs], dim=0
+                )
+                max_extend_len = 1
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+
+            self.forward_metadata = ForwardMetadata(
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_extend_len,
+                None,
+                None,
+                None,
+            )
        elif forward_mode.is_target_verify():
-            seq_lens_sum = seq_lens.sum().item()
-            self.indices_updater_prefill.update(
-                req_pool_indices,
-                seq_lens,
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens,
-                spec_info=spec_info,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                qo_indptr = self.qo_indptr[: bs + 1]
+                qo_indptr[: bs + 1] = torch.arange(
+                    0,
+                    (1 + bs) * self.num_draft_tokens,
+                    step=self.num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                kv_indptr = self.kv_indptr[: bs + 1]
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+                max_extend_len = self.num_draft_tokens
+                kv_last_page_len = None
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_extend_len,
+                    None,
+                    None,
+                    None,
+                )
+            else:
+                seq_lens_sum = seq_lens.sum().item()
+                self.indices_updater_prefill.update(
+                    req_pool_indices,
+                    seq_lens,
+                    seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=encoder_lens,
+                    spec_info=spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
        else:
            raise ValueError(f"Invalid mode: {forward_mode=}")
@@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend):
        if k is not None:
            assert v is not None
            if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-
-        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-
-        bs0 = forward_batch.batch_size + 1
-
-        o = mha_batch_prefill_func(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache,
-            v_cache,
-            self.qo_indptr[:bs0],
-            self.forward_metadata.kv_indptr[:bs0],
-            self.forward_metadata.kv_indices,
-            self.forward_metadata.max_q_len,
-            self.forward_metadata.max_kv_len,
-            causal=True,
-            logits_soft_cap=self.logits_soft_cap,
-            alibi_slopes=None,
-            return_lse=False,
-            return_attn_probs=False,
-        )
-
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+                if self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+        if self.use_mla:
+            max_extend_len = self.forward_metadata.max_extend_len
+            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            kv_indptr = self.forward_metadata.kv_indptr
+            kv_indices = self.forward_metadata.kv_indices
+            kv_last_page_lens = self.forward_metadata.kv_last_page_len
+            qo_indptr = self.forward_metadata.qo_indptr
+            K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            kv_lora_rank = V_Buffer.shape[-1]
+            qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
+            qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
+            assert len(q.shape) == 3
+            assert len(k.shape) == 3
+            assert len(v.shape) == 3
+
+            if kv_indices.shape[0] == 0:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    qo_indptr,
+                    max_extend_len,
+                    max_extend_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                kvc, k_pe = torch.split(
+                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                )
+                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+
+                kvprefix = kvprefix.view(
+                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                )
+                k_prefix, v_prefix = torch.split(
+                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )
+                k_prefix = torch.cat(
+                    [
+                        k_prefix,
+                        torch.broadcast_to(
+                            k_pe,
+                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                        ),
+                    ],
+                    dim=-1,
+                )
+                assert (
+                    forward_batch.extend_prefix_lens.shape
+                    == forward_batch.extend_seq_lens.shape
+                )
+                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
+                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
+                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
+                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
+                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
+                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
+                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_extend_len,
+                    max_prefix_extend_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+        else:
+            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+            bs0 = forward_batch.batch_size + 1
+            o = mha_batch_prefill_func(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache,
+                v_cache,
+                self.qo_indptr[:bs0],
+                self.forward_metadata.kv_indptr[:bs0],
+                self.forward_metadata.kv_indices,
+                self.forward_metadata.max_q_len,
+                self.forward_metadata.max_kv_len,
+                causal=True,
+                logits_soft_cap=self.logits_soft_cap,
+                alibi_slopes=None,
+                return_lse=False,
+                return_attn_probs=False,
+            )
+            return o.view(-1, layer.tp_q_head_num * layer.head_dim)

    def forward_decode(
        self,
@@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend):
        forward_batch: ForwardBatch,
        save_kv_cache=True,
    ):
        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

        if layer.qk_head_dim != layer.v_head_dim:
@@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend):
                layer, forward_batch.out_cache_loc, k, v
            )

-        self.logits_soft_cap = layer.logit_cap
-        paged_attention_ragged(
-            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            self.workspace_buffer,
-            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
-            ),
-            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-                -1, 1, layer.tp_v_head_num, layer.v_head_dim
-            ),
-            self.scale,
-            self.forward_metadata.kv_indptr,
-            self.forward_metadata.kv_indices,
-            self.kv_last_page_lens,
-            1,
-            self.max_num_partitions,
-            None,
-            "auto",
-            "NHD",
-            self.logits_soft_cap,
-            self.k_scale,
-            self.v_scale,
-            None,
-            _AITER_PARTITION_SIZE_ROCM,
-        )
+        if self.use_mla:
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            mla_decode_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                k_buffer.view(-1, 1, 1, layer.qk_head_dim),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                self.forward_metadata.qo_indptr,
+                self.forward_metadata.kv_indptr,
+                self.forward_metadata.kv_indices,
+                self.forward_metadata.kv_last_page_len,
+                self.forward_metadata.max_extend_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+            k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
+        else:
+            self.logits_soft_cap = layer.logit_cap
+            paged_attention_ragged(
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                self.workspace_buffer,
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
+                    -1, 1, layer.tp_k_head_num, layer.qk_head_dim
+                ),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
+                    -1, 1, layer.tp_v_head_num, layer.v_head_dim
+                ),
+                self.scale,
+                self.forward_metadata.kv_indptr,
+                self.forward_metadata.kv_indices,
+                self.kv_last_page_len,
+                1,
+                self.max_num_partitions,
+                None,
+                "auto",
+                "NHD",
+                self.logits_soft_cap,
+                self.k_scale,
+                self.v_scale,
+                None,
+                _AITER_PARTITION_SIZE_ROCM,
+            )
        return o
@@ -506,9 +782,97 @@ class AiterIndicesUpdaterPrefill:
                spec_info.generate_attn_arg_prefill(
                    req_pool_indices,
                    paged_kernel_lens,
-                    None,
+                    paged_kernel_lens_sum,
                    self.req_to_token,
                )
            )

        self.kv_indices = kv_indices
+
+
+class AiterMlaIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Parse Constants
+        self.attn_backend = attn_backend
+
+        # Buffers and wrappers
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.update = self.update_single_wrapper
+
+        self.kv_indptr = None
+        self.kv_indices = None
+        self.qo_indptr = None
+        self.kv_last_page_len = None
+        self.max_extend_len = 0
+        self.max_prefix_extend_len = 0
+
+    def update(
+        self,
+        req_pool_indices: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        prefix_lens_sum: int,
+        extend_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
+    def update_single_wrapper(
+        self,
+        req_pool_indices: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        prefix_lens_sum: int,
+        extend_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+        paged_kernel_lens = prefix_lens
+        paged_kernel_lens_sum = prefix_lens_sum
+
+        bs = len(req_pool_indices)
+
+        kv_indptr = self.attn_backend.kv_indptr
+
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.attn_backend.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+            max_extend_len = torch.max(extend_lens).item()
+            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
+            kv_indptr += qo_indptr
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    paged_kernel_lens_sum,
+                    self.req_to_token,
+                )
+            )
+
+        self.kv_indptr = kv_indptr
+        self.kv_indices = kv_indices
+        self.qo_indptr = qo_indptr
+        self.max_extend_len = max_extend_len
+        self.max_prefix_extend_len = max_prefix_extend_len
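To make the index layout used by `AiterMlaIndicesUpdaterPrefill` concrete, here is a small standalone sketch (toy tensors, not from the commit) of how the cumulative-sum pointers relate per-request prefix and extend lengths to the packed KV and query buffers:

```python
import torch

# Three requests: cached-prefix lengths and newly extended lengths (toy values).
prefix_lens = torch.tensor([4, 0, 2], dtype=torch.int32)
extend_lens = torch.tensor([3, 5, 1], dtype=torch.int32)
bs = prefix_lens.numel()

# kv_indptr[i]:kv_indptr[i+1] is request i's slice of the flattened kv_indices.
kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)      # [0, 4, 4, 6]

# qo_indptr plays the same role for the packed query/output tokens.
qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)      # [0, 3, 8, 9]

# The two max values bound the kernel's per-request loop sizes.
max_extend_len = int(extend_lens.max())                         # 5
max_prefix_extend_len = int((extend_lens + prefix_lens).max())  # 7

print(kv_indptr.tolist(), qo_indptr.tolist(), max_extend_len, max_prefix_extend_len)
```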
@@ -20,10 +20,11 @@ import torch
import torch.nn as nn

from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import get_bool_env_var, is_cuda, is_hip

_is_cuda = is_cuda()
_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _is_cuda:
    from sgl_kernel import (
@@ -33,7 +34,10 @@ if _is_cuda:
        rmsnorm,
    )

-if _is_hip:
+if _use_aiter:
+    from aiter import rmsnorm2d_fwd as rms_norm
+    from aiter import rmsnorm2d_fwd_with_add as fused_add_rms_norm
+elif _is_hip:
    from vllm._custom_ops import fused_add_rms_norm, rms_norm

logger = logging.getLogger(__name__)
@@ -48,6 +52,8 @@ class RMSNorm(CustomOp):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps
+        if _use_aiter:
+            self._forward_method = self.forward_aiter

    def forward_cuda(
        self,
@@ -60,6 +66,25 @@ class RMSNorm(CustomOp):
        out = rmsnorm(x, self.weight.data, self.variance_epsilon)
        return out

+    def forward_aiter(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if residual is not None:
+            residual_out = torch.empty_like(x)
+            output = torch.empty_like(x)
+            fused_add_rms_norm(
+                output,
+                x,
+                residual,
+                residual_out,
+                self.weight.data,
+                self.variance_epsilon,
+            )
+            return output, residual_out
+        return rms_norm(x, self.weight.data, self.variance_epsilon)
+
    def forward_hip(
        self,
        x: torch.Tensor,
...
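For readers unfamiliar with the AITER kernels, a plain-PyTorch sketch of the semantics `forward_aiter` relies on: `rms_norm` normalizes by the root-mean-square over the last dimension, and the fused-add variant folds the residual in first and also returns the updated residual. This is a reference for clarity under those assumptions, not the kernel implementation itself.

```python
import torch

def rms_norm_ref(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # y = x / sqrt(mean(x^2) + eps) * weight, accumulated in float32 for stability.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps)).to(x.dtype) * weight

def fused_add_rms_norm_ref(x, residual, weight, eps):
    # The residual is added before normalization; the pre-norm sum is returned as
    # the new residual, matching the (output, residual_out) tuple in forward_aiter.
    residual_out = x + residual
    return rms_norm_ref(residual_out, weight, eps), residual_out

x = torch.randn(4, 8, dtype=torch.float16)
res = torch.randn_like(x)
w = torch.ones(8, dtype=torch.float16)
out, new_res = fused_add_rms_norm_ref(x, res, w, eps=1e-6)
```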
@@ -1332,7 +1332,7 @@ def fused_experts_impl(
    if (
        not (use_fp8_w8a8 or use_int8_w8a8)
        or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_AITER_MOE"))
+        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
    ):
        padded_size = 0
...
@@ -28,8 +28,9 @@ else:
import logging

_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-if _is_hip:
+if _use_aiter:
    from aiter import ActivationType
    from aiter.fused_moe_bf16_asm import ck_moe_2stages
    from aiter.ops.shuffle import shuffle_weight
@@ -104,7 +105,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
        set_weight_attrs(w2_weight, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
            layer.w13_weight = torch.nn.Parameter(
                shuffle_weight(layer.w13_weight.data, (16, 16)),
                requires_grad=False,
@@ -188,7 +189,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
            routed_scaling_factor=routed_scaling_factor,
        )

-        if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
+        if _use_aiter:
            assert not no_combine, "unsupported"
            if apply_router_weight_on_input:
                assert (
...
@@ -77,8 +77,8 @@ _is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()

-use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+_use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _is_hip:
    from aiter import ActivationType, QuantType
@@ -487,7 +487,7 @@ class Fp8MoEMethod:
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.uint32 if use_hip_int4 else torch.float8_e4m3fn
+            params_dtype = torch.uint32 if _use_hip_int4 else torch.float8_e4m3fn
        tp_size = get_tensor_model_parallel_world_size()
        if self.block_quant:
            block_n, block_k = (
@@ -512,7 +512,7 @@ class Fp8MoEMethod:
            )

        # WEIGHTS
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
            # INT4 MoE weight - INT32 packed
            w13_weight = torch.nn.Parameter(
                torch.empty(
@@ -641,7 +641,7 @@ class Fp8MoEMethod:
            layer.register_parameter("w13_weight_scale", w13_weight_scale)
            layer.register_parameter("w2_weight_scale", w2_weight_scale)

-            if _is_hip:  # and use_aiter_moe: TODO: add check back after triton kernel
+            if _is_hip:  # _use_aiter: TODO: add check back after triton kernel
                # ROCm - using column scaling, duplicate scaling numbers in case per tensor scaling
                w13_weight_scale1 = torch.nn.Parameter(
                    torch.ones(num_experts, 2 * intermediate_size, dtype=torch.float32),
@@ -668,7 +668,7 @@ class Fp8MoEMethod:
            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
            set_weight_attrs(w2_weight_scale, extra_weight_attrs)

-            if _is_hip and use_hip_int4:
+            if _is_hip and _use_hip_int4:
                extra_weight_attrs.update(
                    {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
                )
@@ -700,7 +700,7 @@ class Fp8MoEMethod:
            layer.w2_input_scale = None

    def process_weights_after_loading(self, layer: Module) -> None:
-        if _is_hip and use_hip_int4:
+        if _is_hip and _use_hip_int4:
            self.process_weights_hip_int4(layer)
            return
@@ -731,7 +731,7 @@ class Fp8MoEMethod:
            )
            layer.w2_input_scale = None

-            if _is_hip and use_aiter_moe:
+            if _use_aiter:
                # Pre-shuffle weights
                layer.w13_weight.data = shuffle_weight(
                    layer.w13_weight.contiguous(), (16, 16)
@@ -853,7 +853,7 @@ class Fp8MoEMethod:
        return

    def process_weights_hip_int4(self, layer: Module):
-        # TODO: and use_aiter_moe: add after triton kernel added
+        # TODO: _use_aiter: add after triton kernel added
        # INT4-FP8 (INT4 MoE Weight, FP8 Compute)
        # Weight Permutation
        layer.w13_weight = torch.nn.Parameter(
@@ -900,7 +900,7 @@ class Fp8MoEMethod:
            padding_size,  # Avoid circular import
        )

-        if use_aiter_moe:
+        if _use_aiter:
            layer.w13_weight = torch.nn.Parameter(
                shuffle_weight(layer.w13_weight.data, (16, 16)),
                requires_grad=False,
@@ -911,7 +911,7 @@ class Fp8MoEMethod:
                requires_grad=False,
            )
            torch.cuda.empty_cache()
-            # ROCm (use_aiter_moe): using column-wise scaling
+            # ROCm (_use_aiter): using column-wise scaling
            layer.w13_weight_scale1 *= layer.w13_weight_scale.unsqueeze(-1)
            layer.w2_weight_scale1 *= layer.w2_weight_scale.unsqueeze(-1)
        elif get_bool_env_var("SGLANG_MOE_PADDING"):
@@ -1041,8 +1041,8 @@ class Fp8MoEMethod:
        activation: str = "silu",
        no_combine: bool = False,
    ) -> Optional[torch.Tensor]:
-        if use_hip_int4:
-            # TODO: add triton kernel and add check use_aiter_moe
+        if _use_hip_int4:
+            # TODO: add triton kernel and add check _use_aiter
            assert not no_combine, f"{no_combine=} is not supported."
            return ck_moe_2stages(
                x,
@@ -1058,13 +1058,13 @@ class Fp8MoEMethod:
                ),
            )

-        if use_aiter_moe:
+        if _use_aiter:
            assert not no_combine, f"{no_combine=} is not supported."
            if self.block_quant:
-                # TODO(use_aiter_moe): FP8 block_quant only supports 'silu' for the time-being.
+                # TODO(_use_aiter): FP8 block_quant only supports 'silu' for the time-being.
                assert (
                    activation == "silu"
-                ), f"use_aiter_moe: FP8 bloack_quant {activation=} will be supported later, unset use_aiter_moe"
+                ), f"_use_aiter: FP8 bloack_quant {activation=} will be supported later, unset _use_aiter"
            return asm_moe(
                x,
                layer.w13_weight,
...
@@ -38,11 +38,10 @@ _is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
-
-if _is_hip and use_aiter_moe:
-    from aiter import gemm_a8w8_blockscale
+if _use_aiter:
+    from aiter import gemm_a8w8_blockscale_CK

if _is_cuda:
    from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -141,7 +140,7 @@ def dispatch_w8a8_block_fp8_linear() -> Callable:
        return flashinfer_gemm_w8a8_block_fp8_linear
    elif CUTLASS_BLOCK_FP8_SUPPORTED:
        return cutlass_w8a8_block_fp8_linear_with_fallback
-    elif _is_hip and use_aiter_moe:
+    elif _use_aiter:
        return aiter_w8a8_block_fp8_linear
    elif _ENABLE_JIT_DEEPGEMM:
        return deepgemm_w8a8_block_fp8_linear_with_fallback
@@ -268,12 +267,9 @@ def aiter_w8a8_block_fp8_linear(
    q_input, x_scale = per_token_group_quant_fp8(
        input_2d, block_size[1], column_major_scales=False
    )
-    output = torch.zeros(
-        [q_input.shape[0], weight.shape[0]],
-        dtype=input_2d.dtype,
-        device=q_input.device,
+    output = gemm_a8w8_blockscale_CK(
+        q_input, weight, x_scale, weight_scale, dtype=input.dtype
    )
-    gemm_a8w8_blockscale(q_input, weight, x_scale, weight_scale, output)

    if bias is not None:
        output += bias
...
@@ -355,6 +355,15 @@ class ModelRunner:
                # MLA architecture
                if is_hopper_with_cuda_12_3():
                    server_args.attention_backend = "fa3"
+                elif _is_hip:
+                    head_num = self.model_config.get_num_kv_heads(self.tp_size)
+                    # TODO current aiter only support head number 16 or 128 head number
+                    if (
+                        head_num == 128 or head_num == 16
+                    ) and self.spec_algorithm.is_none():
+                        server_args.attention_backend = "aiter"
+                    else:
+                        server_args.attention_backend = "triton"
                else:
                    server_args.attention_backend = "triton"
                logger.info(
@@ -363,6 +372,7 @@ class ModelRunner:
        elif self.use_mla_backend:
            if server_args.device != "cpu":
                if server_args.attention_backend in [
+                    "aiter",
                    "flashinfer",
                    "fa3",
                    "triton",
...
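The selection logic above can be read as a small pure function; the sketch below restates it (it is not the ModelRunner code) so the rule is easy to scan: on ROCm, MLA models fall back to the Triton backend unless the per-TP KV head count is one AITER currently handles (16 or 128) and speculative decoding is off.

```python
def choose_mla_attention_backend_on_hip(num_kv_heads: int, spec_algorithm_is_none: bool) -> str:
    # Mirrors the branch added to ModelRunner above: AITER's MLA path currently
    # supports only 16 or 128 KV heads and no speculative algorithm.
    if num_kv_heads in (16, 128) and spec_algorithm_is_none:
        return "aiter"
    return "triton"

assert choose_mla_attention_backend_on_hip(128, True) == "aiter"
assert choose_mla_attention_backend_on_hip(16, False) == "triton"
assert choose_mla_attention_backend_on_hip(32, True) == "triton"
```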
@@ -105,6 +105,7 @@ from sglang.srt.utils import (
_is_hip = is_hip()
_is_cuda = is_cuda()
_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _is_cuda:
    from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -120,6 +121,9 @@ if _is_hip:
        decode_attention_fwd_grouped_rope,
    )

+if _use_aiter:
+    from aiter.rotary_embedding import get_rope
+
logger = logging.getLogger(__name__)
@@ -697,6 +701,7 @@ class DeepseekV2AttentionMLA(nn.Module):
        )
        self.alt_stream = alt_stream
+        self.attn_mha.kv_b_proj = None

        self.w_kc = None
        self.w_vc = None
@@ -766,6 +771,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                return AttnForwardMethod.MHA_CHUNKED_KV
            else:
                return _dispatch_mla_subtype()
+        elif self.attention_backend == "aiter":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
        else:
            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
            if (
@@ -813,6 +827,9 @@ class DeepseekV2AttentionMLA(nn.Module):
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ):
+        if self.attn_mha.kv_b_proj is None:
+            self.attn_mha.kv_b_proj = self.kv_b_proj
+
        if hidden_states.shape[0] == 0:
            assert (
                not self.o_proj.reduce_results
...
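Similarly, the forward-method dispatch added for the `aiter` backend boils down to: ordinary prefill/extend batches run the plain MHA path (using the `kv_b_proj` handle wired up above), while decode, target-verify, and draft-extend batches take the MLA path. A standalone restatement (not the model code):

```python
def aiter_attn_forward_method(is_extend: bool, is_target_verify: bool, is_draft_extend: bool) -> str:
    # Mirrors DeepseekV2AttentionMLA's dispatch when attention_backend == "aiter".
    if is_extend and not (is_target_verify or is_draft_extend):
        return "MHA"  # normal prefill/extend
    return "MLA"      # decode and speculative verify/extend

assert aiter_attn_forward_method(True, False, False) == "MHA"
assert aiter_attn_forward_method(False, False, False) == "MLA"
assert aiter_attn_forward_method(True, True, False) == "MLA"
```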
#!/bin/bash
set -euo pipefail

-# Default working directory
WORKDIR="/sglang-checkout/test/srt"

-ENV_ARGS=(
-  -e SGLANG_AMD_CI=1
-  -e SGLANG_IS_IN_CI=1
-  -e SGLANG_AITER_MOE=1
+declare -A ENV_MAP=(
+  [SGLANG_AMD_CI]=1
+  [SGLANG_IS_IN_CI]=1
+  [SGLANG_USE_AITER]=1
)

-# Parse optional -w/--workdir and -e ENV=VAL flags
+# Parse -w/--workdir and -e ENV=VAL
while [[ $# -gt 0 ]]; do
  case "$1" in
    -w|--workdir)
@@ -17,7 +16,8 @@ while [[ $# -gt 0 ]]; do
      shift 2
      ;;
    -e)
-      ENV_ARGS+=("-e" "$2")
+      IFS="=" read -r key val <<< "$2"
+      ENV_MAP["$key"]="$val"
      shift 2
      ;;
    --)
@@ -30,6 +30,12 @@ while [[ $# -gt 0 ]]; do
  esac
done

+# Build final ENV_ARGS
+ENV_ARGS=()
+for key in "${!ENV_MAP[@]}"; do
+  ENV_ARGS+=("-e" "$key=${ENV_MAP[$key]}")
+done
+
# Run docker exec
docker exec \
  -w "$WORKDIR" \
...
@@ -171,7 +171,7 @@ class TestNightlyGsm8KEval(unittest.TestCase):
            os.environ["HF_HUB_DISABLE_XET"] = (
                "1" if model in DISABLE_HF_XET_MODELS else "0"
            )
-            os.environ["SGLANG_AITER_MOE"] = (
+            os.environ["SGLANG_USE_AITER"] = (
                "0" if model in TRITON_MOE_MODELS else "1"
            )
...