"examples/frontend_language/quick_start/local_example_chat.py" did not exist on "931213245ce69908843c731edbee7bd662f0647b"
Commit a1175a4e authored by maxiao1's avatar maxiao1
Browse files

Merge remote-tracking branch 'origin/v0.5.4_dev' into sglang_v0.5.5

parents 0c006b88 31653dd9
......@@ -168,6 +168,7 @@ MLA_ATTENTION_BACKENDS = [
"triton",
"flashmla",
"cutlass_mla",
"dcu_mla",
"trtllm_mla",
"ascend",
"nsa",
......@@ -176,6 +177,7 @@ MLA_ATTENTION_BACKENDS = [
CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
"flashinfer",
"fa3",
"dcu_mla",
"fa4",
"flashmla",
"cutlass_mla",
......@@ -207,7 +209,7 @@ _is_xpu_xmx_available = xpu_has_xmx_support()
SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
# Detect straggler ranks in model loading
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
UNBALANCED_MODEL_LOADING_TIMEOUT_S = 3600
# The ratio of mamba cache pool size to max_running_requests; it is safe when larger than 2 (yizhang2077)
MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3
......@@ -511,6 +513,121 @@ class ModelRunner:
def model_specific_adjustment(self):
server_args = self.server_args
if (
server_args.attention_backend == "intel_amx"
and server_args.device == "cpu"
and not _is_cpu_amx_available
):
logger.info(
"The current platform does not support Intel AMX, will fallback to torch_native backend."
)
server_args.attention_backend = "torch_native"
if (
server_args.attention_backend == "intel_xpu"
and server_args.device == "xpu"
and not _is_xpu_xmx_available
):
logger.info(
"The current platform does not support Intel XMX, will fallback to triton backend."
)
server_args.attention_backend = "triton"
if server_args.prefill_attention_backend is not None and (
server_args.prefill_attention_backend
== server_args.decode_attention_backend
): # override the default attention backend
server_args.attention_backend = server_args.prefill_attention_backend
if (
getattr(self.model_config.hf_config, "dual_chunk_attention_config", None)
is not None
):
if server_args.attention_backend is None:
server_args.attention_backend = "dual_chunk_flash_attn"
logger.info("Dual chunk attention is turned on by default.")
elif server_args.attention_backend != "dual_chunk_flash_attn":
raise ValueError(
"Dual chunk attention is enabled, but attention backend is set to "
f"{server_args.attention_backend}. Please set it to 'dual_chunk_flash_attn'."
)
if server_args.attention_backend is None:
"""
Auto select the fastest attention backend.
1. Models with MHA Architecture (e.g: Llama, QWen)
1.1 We will turn on FA3 on hopper unless user use spec decode with topk > 1 or page_size > 1.
1.2 In other cases, we will use flashinfer if available, otherwise use triton.
2. Models with MLA Architecture and using FA3
2.1 We will use FA3 backend on hopper.
2.2 We will use Flashinfer backend on blackwell.
2.3 Otherwise, we will use triton backend.
"""
if not self.use_mla_backend:
# MHA architecture
if (
is_hopper_with_cuda_12_3()
and is_no_spec_infer_or_topk_one(server_args)
and is_fa3_default_architecture(self.model_config.hf_config)
):
server_args.attention_backend = "fa3"
elif _is_hip:
server_args.attention_backend = "triton"
elif _is_npu:
server_args.attention_backend = "ascend"
else:
server_args.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
# MLA architecture
if is_hopper_with_cuda_12_3():
server_args.attention_backend = "fa3"
elif is_sm100_supported():
server_args.attention_backend = "flashinfer"
elif _is_hip:
head_num = self.model_config.get_num_kv_heads(self.tp_size)
# TODO: the current aiter only supports head numbers of 16 or 128
if head_num == 128 or head_num == 16:
server_args.attention_backend = "triton"
else:
server_args.attention_backend = "triton"
elif _is_npu:
server_args.attention_backend = "ascend"
else:
server_args.attention_backend = "triton"
log_info_on_rank0(
logger,
f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
)
elif self.use_mla_backend:
if server_args.device != "cpu":
if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
logger.info(
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
)
else:
raise ValueError(
f"Invalid attention backend for MLA: {server_args.attention_backend}"
)
else:
if server_args.attention_backend != "intel_amx":
raise ValueError(
"MLA optimization not supported on CPU except for intel_amx backend."
)
if (
server_args.attention_backend == "fa3"
and server_args.kv_cache_dtype == "fp8_e5m2"
):
logger.warning(
"FlashAttention3 only supports fp8_e4m3 if using FP8; "
"Setting attention backend to triton."
)
server_args.attention_backend = "triton"
if server_args.enable_double_sparsity:
logger.info(
"Double sparsity optimization is turned on. Use triton backend without CUDA graph."
......@@ -1521,12 +1638,14 @@ class ModelRunner:
self.kv_cache_dtype = self.dtype
elif self.server_args.kv_cache_dtype == "fp8_e5m2":
if _is_hip: # Using natively supported format
self.kv_cache_dtype = torch.float8_e5m2fnuz
# self.kv_cache_dtype = torch.float8_e5m2fnuz
self.kv_cache_dtype = torch.float8_e5m2
else:
self.kv_cache_dtype = torch.float8_e5m2
elif self.server_args.kv_cache_dtype == "fp8_e4m3":
if _is_hip: # Using natively supported format
self.kv_cache_dtype = torch.float8_e4m3fnuz
# self.kv_cache_dtype = torch.float8_e4m3fnuz
self.kv_cache_dtype = torch.float8_e4m3fn
else:
self.kv_cache_dtype = torch.float8_e4m3fn
elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
......
......@@ -137,6 +137,7 @@ from sglang.srt.utils import (
make_layers,
use_intel_amx_backend,
)
from sglang.srt.layers.attention.lightop_concat import concat_decode_opt
_is_hip = is_hip()
_is_cuda = is_cuda()
......@@ -147,8 +148,10 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_device_sm = get_device_sm()
_is_gfx95_supported = is_gfx95_supported()
_user_lightop_moe_sum_mul_add = get_bool_env_var("SGLANG_USE_LIGHTOP_MOE_SUM_MUL_ADD")
_use_fused_silu_mul_quant = get_bool_env_var("SGLANG_USE_FUSED_SILU_MUL_QUANT")
_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
_use_opt_cat_decode = get_bool_env_var("SGLANG_USE_OPT_CAT")
if _use_aiter_gfx95:
from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
......@@ -181,6 +184,7 @@ elif _is_hip:
from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_triton as awq_dequantize,
)
from sgl_kernel import merge_state_v2
elif _is_npu:
import custom_ops # noqa: F401
import sgl_kernel_npu # noqa: F401
......@@ -366,6 +370,10 @@ def handle_attention_flashmla(attn, forward_batch):
return _handle_attention_backend(attn, forward_batch, "flashmla")
def handle_attention_dcu_mla(attn, forward_batch):
return _handle_attention_backend(attn, forward_batch, "dcu_mla")
def handle_attention_cutlass_mla(attn, forward_batch):
return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
......@@ -507,11 +515,13 @@ class DeepseekV2MLP(nn.Module):
x = (x, None, y)
gate_up, _ = self.gate_up_proj(x)
x = self.act_fn(gate_up)
x, _ = self.down_proj(
x,
skip_all_reduce=should_allreduce_fusion or use_reduce_scatter,
)
if _use_fused_silu_mul_quant:
x, _ = self.down_proj(gate_up, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter, use_fused_silu_mul_quant=True)
else:
x = self.act_fn(gate_up)
x, _ = self.down_proj(
x, skip_all_reduce=should_allreduce_fusion or use_reduce_scatter
)
return x
......@@ -811,52 +821,58 @@ class DeepseekV2MoE(nn.Module):
self.shared_experts.gate_up_proj
):
return self.forward_cpu(hidden_states, should_allreduce_fusion)
if hidden_states.shape[0] > 0:
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
if _user_lightop_moe_sum_mul_add:
if hidden_states.shape[0] > 0:
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
final_hidden_states = self.experts(hidden_states, topk_output, shared_output=shared_output)
else:
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
if hidden_states.shape[0] > 0:
if not self._fuse_shared_experts_inside_sbo:
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
# router_logits: (num_tokens, n_experts)
router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
topk_output = self.topk(hidden_states, router_logits)
else:
shared_output = None
topk_output = self.topk.empty_topk_output(hidden_states.device)
if self._fuse_shared_experts_inside_sbo:
shared_output = None
if self._fuse_shared_experts_inside_sbo:
shared_output = None
def _forward_shared_experts_and_put_results():
nonlocal shared_output
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
final_hidden_states = self.experts(
hidden_states,
topk_output,
**(
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
)
if self._fuse_shared_experts_inside_sbo
else {}
),
)
if (
not _is_cuda
and not _use_aiter
or isinstance(
self.experts.quant_method, CompressedTensorsWNA16AMXEPMoEMethod
def _forward_shared_experts_and_put_results():
nonlocal shared_output
shared_output = self._forward_shared_experts(
hidden_states, gemm_output_zero_allocator
)
final_hidden_states = self.experts(
hidden_states,
topk_output,
**(
dict(
forward_shared_experts=_forward_shared_experts_and_put_results,
alt_stream=self.alt_stream,
)
if self._fuse_shared_experts_inside_sbo
else {}
),
)
or isinstance(self.experts.quant_method, CompressedTensorsWNA16MoEMethod)
):
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
final_hidden_states += shared_output
if not _is_cuda and not _use_aiter:
# fused in biased_grouped_topk so we can skip here
final_hidden_states *= self.routed_scaling_factor
if shared_output is not None:
with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
final_hidden_states_out = torch.empty_like(final_hidden_states)
torch.add(final_hidden_states, shared_output, out=final_hidden_states_out)
final_hidden_states = final_hidden_states_out
sm.tag(final_hidden_states)
if (
self.tp_size > 1
and not should_allreduce_fusion
......@@ -1766,7 +1782,10 @@ class DeepseekV2AttentionMLA(nn.Module):
self.rotary_emb.is_neox_style,
)
else:
q = torch.cat([q_nope_out, q_pe], dim=-1)
if _use_opt_cat_decode and q_nope_out.shape[0] < 1024:
q = concat_decode_opt(q_nope_out, q_pe, dim=2)
else:
q = torch.cat([q_nope_out, q_pe], dim=-1)
k = torch.cat([k_nope, k_pe], dim=-1)
attn_output = self.attn_mqa(
......@@ -2365,9 +2384,20 @@ class DeepseekV2AttentionMLA(nn.Module):
kv_indices = forward_batch.prefix_chunk_kv_indices[i]
# Fetch latent cache from memory pool with precomputed chunked kv indices
kv_a_normed, k_pe = self._get_mla_kv_buffer(
kv_indices, q.dtype, forward_batch
latent_cache_buf, dtype = forward_batch.token_to_kv_pool.get_key_buffer_DeepSeekV2(
self.attn_mha.layer_id
)
latent_cache = (
latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
.contiguous()
.view(dtype)
.to(q.dtype)
)
kv_a_normed, k_pe = latent_cache.split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
)
kv_a_normed = kv_a_normed.squeeze(1).contiguous()
kv = self.kv_b_proj(kv_a_normed)[0]
kv = kv.view(
-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim
......@@ -3838,6 +3868,7 @@ AttentionBackendRegistry.register("ascend", handle_attention_ascend)
AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
AttentionBackendRegistry.register("fa3", handle_attention_fa3)
AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
AttentionBackendRegistry.register("dcu_mla", handle_attention_dcu_mla)
AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
AttentionBackendRegistry.register("fa4", handle_attention_fa4)
AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
......
......@@ -351,7 +351,7 @@ class Qwen3GatedDeltaNet(nn.Module):
def _forward_input_proj(self, hidden_states: torch.Tensor):
DUAL_STREAM_TOKEN_THRESHOLD = 1024 if not _is_npu else 0
seq_len, _ = hidden_states.shape
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD:
if seq_len < DUAL_STREAM_TOKEN_THRESHOLD and self.alt_stream is not None:
current_stream = torch.cuda.current_stream()
self.alt_stream.wait_stream(current_stream)
projected_states_qkvz, _ = self.in_proj_qkvz(hidden_states)
......
from ctypes import cdll, c_char_p, c_int
import os
import threading
import time


class Prof:
    """Thin wrapper around roctracer/rocTX range markers, enabled via SGLANG_HIP_PROF."""

    def __init__(self):
        self.use_roctx = os.getenv("SGLANG_HIP_PROF") is not None
        self.lib = None
        self.roc_tracer_flag = False
        if self.use_roctx:
            self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctxRangePushA.argtypes = [c_char_p]
            self.lib.roctxRangePushA.restype = c_int
            self.lib.roctxRangePop.restype = c_int
        self.tm = time.perf_counter()
        # Per-thread nesting depth of pushed ranges.
        self.push_depth = {}

    def StartTracer(self):
        if self.use_roctx:
            if self.lib is None:
                self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctracer_start()
            self.roc_tracer_flag = True

    def StopTracer(self):
        if self.use_roctx:
            if self.lib is None:
                self.lib = cdll.LoadLibrary("libroctracer64.so")
            self.lib.roctracer_stop()
            self.roc_tracer_flag = False

    def thread_depth_add(self, num):
        thread_id = threading.current_thread().ident
        if thread_id not in self.push_depth:
            self.push_depth[thread_id] = 0
        if num < 0 and self.push_depth[thread_id] == 0:
            return False
        self.push_depth[thread_id] += num
        return True

    def ProfRangePush(self, message):
        if self.use_roctx and self.roc_tracer_flag:
            # Push once per call so the depth bookkeeping stays balanced with ProfRangePop.
            self.lib.roctxRangePushA(message.encode("utf-8"))
            self.thread_depth_add(1)

    def ProfRangePop(self):
        if self.use_roctx and self.roc_tracer_flag:
            if not self.thread_depth_add(-1):
                return
            self.lib.roctxRangePop()

    def ProfRangeAutoPush(self, message):
        self.ProfRangePop()
        self.ProfRangePush(message)


profile = Prof()
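A minimal usage sketch for this profiling helper, assuming SGLANG_HIP_PROF is set before the process starts; the import path below is a placeholder, since the diff does not show where this file lives in the tree.

# Hypothetical import path; adjust to wherever this module is placed.
from sglang.srt.hip_prof import profile  # assumption, not confirmed by this diff

profile.StartTracer()
profile.ProfRangePush("prefill_forward")
# ... run the region to be profiled under roctracer ...
profile.ProfRangePop()
profile.StopTracer()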
......@@ -103,6 +103,8 @@ QUANTIZATION_CHOICES = [
"mxfp4",
"auto-round",
"compressed-tensors", # for Ktransformers
"slimquant_w4a8_marlin",
"slimquant_marlin",
]
ATTENTION_BACKEND_CHOICES = [
......@@ -111,6 +113,8 @@ ATTENTION_BACKEND_CHOICES = [
"torch_native",
"flex_attention",
"nsa",
# transplanted from vLLM
"dcu_mla",
# NVIDIA specific
"cutlass_mla",
"fa3",
......@@ -1198,9 +1202,11 @@ class ServerArgs:
if (
self.attention_backend == "flashmla"
or self.decode_attention_backend == "flashmla"
or self.attention_backend == "dcu_mla"
or self.decode_attention_backend == "dcu_mla"
):
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
"FlashMLA/DCU MLA only supports a page_size of 64, change page_size to 64."
)
self.page_size = 64
......
......@@ -46,6 +46,7 @@ class DraftBackendFactory:
else self._create_triton_decode_backend
),
"flashmla": self._create_flashmla_decode_backend,
"dcu_mla": self._create_dcumla_decode_backend,
"trtllm_mha": self._create_trtllm_mha_decode_backend,
"trtllm_mla": self._create_trtllm_mla_decode_backend,
"nsa": self._create_nsa_decode_backend,
......@@ -70,6 +71,7 @@ class DraftBackendFactory:
else self._create_triton_prefill_backend
),
"flashmla": self._create_flashmla_prefill_backend,
"dcu_mla": self._create_dcumla_prefill_backend,
"trtllm_mha": self._create_trtllm_mha_prefill_backend,
"trtllm_mla": self._create_trtllm_mla_prefill_backend,
"nsa": self._create_nsa_prefill_backend,
......@@ -151,6 +153,15 @@ class DraftBackendFactory:
return FlashMLAMultiStepDraftBackend(
self.draft_model_runner, self.topk, self.speculative_num_steps
)
def _create_dcumla_decode_backend(self):
from sglang.srt.layers.attention.dcu_mla_backend import (
DCUMLAMultiStepDraftBackend,
)
return DCUMLAMultiStepDraftBackend(
self.draft_model_runner, self.topk, self.speculative_num_steps
)
def _create_trtllm_mha_decode_backend(self):
from sglang.srt.layers.attention.trtllm_mha_backend import (
......@@ -240,3 +251,16 @@ class DraftBackendFactory:
"flashmla prefill backend is not yet supported for draft extend."
)
return None
def _create_dcumla_prefill_backend(self):
# Fall back to the FlashAttention backend for draft-extend prefill.
from sglang.srt.layers.attention.flashattention_backend import (
FlashAttentionBackend,
)
return FlashAttentionBackend(self.draft_model_runner, skip_prefill=False)
......@@ -38,9 +38,8 @@ from sglang.srt.speculative.spec_utils import (
get_src_tgt_cache_loc,
get_target_cache_loc,
)
from sglang.srt.utils import is_cuda, is_npu, next_power_of_2
_is_npu = is_npu()
from sglang.srt.utils import is_cuda, is_hip, next_power_of_2, get_bool_env_var
from sgl_kernel.kvcacheio import dcu_create_extend_after_decode_spec_info
if is_cuda():
from sgl_kernel import (
......@@ -620,6 +619,8 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
new_seq_lens: Optional[torch.Tensor] = None
verify_done: Optional[torch.cuda.Event] = None
use_sglang_create_extend_after_decode_spec_info = get_bool_env_var("SGLANG_CREATE_EXTEND_AFTER_DECODE_SPEC_INFO")
def __post_init__(self):
super().__init__(SpecInputType.EAGLE_DRAFT)
......@@ -684,14 +685,24 @@ class EagleDraftInput(SpecInput, EagleDraftInputV2Mixin):
self.positions = torch.empty_like(batch.input_ids, dtype=torch.long)
self.verified_id = torch.empty_like(self.accept_length, dtype=torch.int32)
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
if self.use_sglang_create_extend_after_decode_spec_info:
dcu_create_extend_after_decode_spec_info(
verified_id=batch.input_ids,
seq_lens=batch.seq_lens,
accept_lens=self.accept_length,
positions=self.positions,
new_verified_id=self.verified_id,
bs=max(speculative_num_steps + 1, len(batch.seq_lens)),
)
else:
create_extend_after_decode_spec_info[(len(batch.seq_lens),)](
batch.input_ids,
batch.seq_lens,
self.accept_length,
self.positions,
self.verified_id,
next_power_of_2(max(speculative_num_steps + 1, len(batch.seq_lens))),
)
def generate_attn_arg_prefill(
self,
......
......@@ -34,6 +34,12 @@ _is_cuda = is_cuda()
_is_hip = is_hip()
_is_npu = is_npu()
from sglang.srt.utils import get_bool_env_var
from sgl_kernel.kvcacheio import dcu_assign_req_to_token_pool, dcu_assign_extend_cache_locs
import logging
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.managers.tp_worker import TpModelWorker
from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
......@@ -79,6 +85,9 @@ def assign_draft_cache_locs_page_size_1(
@dataclass
class EagleDraftInputV2Mixin:
use_sglang_assign_req_to_token_pool = get_bool_env_var("SGLANG_ASSIGN_REQ_TO_TOKEN_POOL")
def prepare_for_decode(self: EagleDraftInput, batch: ScheduleBatch):
from sglang.srt.speculative.spec_utils import assign_req_to_token_pool_func
......@@ -114,15 +123,26 @@ class EagleDraftInputV2Mixin:
extend_num_tokens,
)
assign_req_to_token_pool_func(
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
bs,
)
if self.use_sglang_assign_req_to_token_pool:
dcu_assign_req_to_token_pool(
req_pool_indices=batch.req_pool_indices,
req_to_token=batch.req_to_token_pool.req_to_token,
allocate_lens=self.allocate_lens,
new_allocate_lens=new_allocate_lens,
out_cache_loc=out_cache_loc,
shape=batch.req_to_token_pool.req_to_token.shape[1],
bs=bs,
)
else:
assign_req_to_token_pool[(bs,)](
batch.req_pool_indices,
batch.req_to_token_pool.req_to_token,
self.allocate_lens,
new_allocate_lens,
out_cache_loc,
batch.req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
self.allocate_lens = new_allocate_lens
# FIXME(lsyin): make this sync optional
......@@ -191,6 +211,9 @@ class EagleDraftInputV2Mixin:
@dataclass
class EagleVerifyInputV2Mixin:
use_sglang_assign_extend_cache_locs = get_bool_env_var("SGLANG_ASSIGN_EXTEND_CACHE_LOCS")
def prepare_for_v2_verify(
self: EagleVerifyInput,
req_to_token_pool: ReqToTokenPool,
......@@ -211,6 +234,27 @@ class EagleVerifyInputV2Mixin:
device=device,
)
if self.use_sglang_assign_extend_cache_locs:
dcu_assign_extend_cache_locs(
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
bs,
)
else:
assign_extend_cache_locs[(bs,)](
batch.req_pool_indices,
req_to_token_pool.req_to_token,
batch.seq_lens,
batch.seq_lens + self.draft_token_num,
batch.out_cache_loc,
req_to_token_pool.req_to_token.shape[1],
next_power_of_2(bs),
)
# Get a forward batch
batch.forward_mode = ForwardMode.TARGET_VERIFY
batch.capture_hidden_mode = CaptureHiddenMode.FULL
......
......@@ -165,10 +165,10 @@ DINLINE void start_sync(
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __MEMORY_SCOPE_SYSTEM);
__hip_atomic_store(
&sg.signals[threadIdx.x]->start[blockIdx.x][rank], flag, __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __MEMORY_SCOPE_DEVICE) <
while (__hip_atomic_load(&self_sg->start[blockIdx.x][threadIdx.x], __ATOMIC_RELAXED, __HIP_MEMORY_SCOPE_AGENT) <
flag)
;
}
......@@ -211,16 +211,16 @@ DINLINE void end_sync(
if (threadIdx.x < ngpus) {
// simultaneously write to the corresponding flag of all ranks.
// Latency = 1 p2p write
__scoped_atomic_store_n(
__hip_atomic_store(
&sg.signals[threadIdx.x]->end[blockIdx.x][rank],
flag,
final_sync ? __ATOMIC_RELAXED : __ATOMIC_RELEASE,
__MEMORY_SCOPE_SYSTEM);
__HIP_MEMORY_SCOPE_SYSTEM);
// wait until we got true from all ranks
while (__scoped_atomic_load_n(
while (__hip_atomic_load(
&self_sg->end[blockIdx.x][threadIdx.x],
final_sync ? __ATOMIC_RELAXED : __ATOMIC_ACQUIRE,
__MEMORY_SCOPE_DEVICE) < flag)
__HIP_MEMORY_SCOPE_AGENT) < flag)
;
}
__syncthreads();
......
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_bf16.h>
#endif
#include <algorithm>
#include <optional>
#include "pytorch_extension_utils.h"
#include "pytorch_extension_utils_rocm.h"
// Helper functions to convert between different data types
// (float, half, bfloat16) for the merge attention states kernel.
......@@ -27,6 +31,19 @@ inline __device__ void from_float(__nv_bfloat16& d, float s) {
d = __float2bfloat16(s);
}
inline void check_shape(const at::Tensor& a, const at::Tensor& b, const char* a_name, const char* b_name) {
TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ", a.dim(), " vs ", b.dim());
for (int i = 0; i < a.dim(); ++i) {
TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name, ".size(", i, ")");
}
}
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) TORCH_CHECK((a) == (b), "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
#define CHECK_GE(a, b) TORCH_CHECK((a) >= (b), "CHECK_GE(" #a ", " #b ") failed. ", a, " vs ", b)
// Implements section 2.2 of https://www.arxiv.org/pdf/2501.01005
template <typename scalar_t, const uint NUM_THREADS>
__global__ void merge_attn_states_kernel(
......
......@@ -19,6 +19,14 @@ limitations under the License.
#include "sgl_kernel_ops.h"
TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From FlashMLA
*/
m.def("dcu_create_flashmla_kv_indices(Tensor req_to_token, Tensor req_pool_indices,Tensor page_kernel_lens, Tensor? kv_start_idx, Tensor kv_indices, int req_to_token_stride, int kv_indices_stride, int PAGED_SIZE) -> ()");
m.impl("dcu_create_flashmla_kv_indices", torch::kCUDA, &dcu_create_flashmla_kv_indices);
/*
* From csrc/activation
*/
......@@ -34,6 +42,12 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
m.def("gelu_quick(Tensor! out, Tensor input) -> ()");
m.impl("gelu_quick", torch::kCUDA, &gelu_quick);
/*
* From csrc/attention
*/
m.def("merge_state_v2(Tensor v_a, Tensor s_a, Tensor v_b, Tensor s_b, Tensor! v_merged, Tensor! s_merged) -> ()");
m.impl("merge_state_v2", torch::kCUDA, &merge_state_v2);
/*
* From csrc/allreduce
*/
......@@ -119,6 +133,22 @@ TORCH_LIBRARY_EXPAND(sgl_kernel, m) {
/*
* From csrc/kvcacheio
*/
m.def("dcu_create_extend_after_decode_spec_info(Tensor verified_id, Tensor seq_lens, Tensor accept_lens, Tensor positions, Tensor new_verified_id, int bs) -> ()");
m.impl("dcu_create_extend_after_decode_spec_info", torch::kCUDA, &dcu_create_extend_after_decode_spec_info);
m.def("dcu_create_chunked_prefix_cache_kv_indices(Tensor req_to_token, Tensor req_pool_indices, Tensor chunk_starts, Tensor chunk_seq_lens, Tensor chunk_cu_seq_lens, Tensor chunk_kv_indices, int col_num, int bs) -> ()");
m.impl("dcu_create_chunked_prefix_cache_kv_indices", torch::kCUDA, &dcu_create_chunked_prefix_cache_kv_indices);
m.def("dcu_assign_extend_cache_locs(Tensor req_pool_indices, Tensor req_to_token, Tensor start_offset, Tensor end_offset, Tensor out_cache_loc, int pool_len, int bs) -> ()");
m.impl("dcu_assign_extend_cache_locs", torch::kCUDA, &dcu_assign_extend_cache_locs);
m.def("dcu_get_last_loc(Tensor req_to_token, Tensor req_pool_indices, Tensor prefix_lens) -> Tensor");
m.impl("dcu_get_last_loc", torch::kCUDA, &dcu_get_last_loc);
m.def("dcu_assign_req_to_token_pool(Tensor req_pool_indices_ptr,Tensor req_to_token_ptr,Tensor allocate_lens_ptr,Tensor new_allocate_lens,Tensor out_cache_loc_ptr,int shape,int bs) -> ()");
m.impl("dcu_assign_req_to_token_pool",torch::kCUDA,&dcu_assign_req_to_token_pool);
m.def("dcu_alloc_extend_kernel(Tensor pre_lens_ptr, Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_extend_kernel", torch::kCUDA, &dcu_alloc_extend_kernel);
m.def("dcu_alloc_decode_kernel(Tensor seq_lens_ptr, Tensor last_loc_ptr, Tensor free_page_ptr, Tensor out_indices, int bs, int page_size) -> ()");
m.impl("dcu_alloc_decode_kernel", torch::kCUDA, &dcu_alloc_decode_kernel);
m.def(
"transfer_kv_per_layer(Tensor src_k, Tensor dst_k, Tensor src_v, Tensor dst_v, Tensor src_indices, Tensor "
"dst_indices, int item_size, int block_quota, int num_warps_per_block) -> ()");
......
......@@ -25,7 +25,7 @@
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
#define WARP_SIZE 64
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 16
......
......@@ -25,7 +25,7 @@
#define INTRIN_M 16
#define INTRIN_N 16
#define INTRIN_K 32
#define WARP_SIZE 32
#define WARP_SIZE 64
#define SMEM_PAD_A 0
#define SMEM_PAD_B 0
#define PACK_SIZE 16
......
......@@ -5,7 +5,7 @@
#include <cstdint>
#ifndef USE_ROCM
#define WARP_SIZE 32
#define WARP_SIZE 64
#include "pytorch_extension_utils.h"
#else
#include "pytorch_extension_utils_rocm.h"
......@@ -805,3 +805,587 @@ void transfer_kv_all_layer_direct_lf_pf(
int64_t page_size) {
transfer_kv_page_first_direct_impl<true>(src_ptrs, dst_ptrs, src_indices, dst_indices, 0, page_size);
}
__device__ int64_t ceil_div(int64_t a, int64_t b) {
return (a + b - 1) / b;
}
__device__ int64_t safe_min(int64_t a, int64_t b) {
return a < b ? a : b;
}
__global__ void launch_alloc_decode_kernel(
const int64_t* seq_lens_ptr,
const int32_t* last_loc_ptr,
const int64_t* free_page_ptr,
int64_t* out_indices,
int64_t bs,
int64_t page_size) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_len = seq_lens_ptr[pid];
int64_t pre_len = seq_len - 1;
int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
int64_t sum_num_new_pages = 0;
for (int64_t i = 0; i <= pid; i++) {
int64_t other_seq_len = seq_lens_ptr[i];
int64_t other_pre_len = (i <= pid) ? (other_seq_len - 1) : other_seq_len;
int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
sum_num_new_pages += other_num_new_pages;
}
int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
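// If this decode step still fits in the request's current page, append right after last_loc;
// otherwise place the token at the first slot of the newly allocated page.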
if (num_page_start_loc_self == 0) {
int32_t last_loc = last_loc_ptr[pid];
out_indices[pid] = last_loc + 1;
} else {
int64_t page = free_page_ptr[new_page_start_loc];
out_indices[pid] = page * page_size;
}
}
__global__ void launch_alloc_extend_kernel(
const int64_t* pre_lens_ptr,
const int64_t* seq_lens_ptr,
const int64_t* last_loc_ptr,
const int64_t* free_page_ptr,
int64_t* out_indices,
int64_t bs,
int64_t page_size)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_len = seq_lens_ptr[pid];
int64_t pre_len = pre_lens_ptr[pid];
int64_t extend_len = seq_len - pre_len;
int64_t sum_extend_lens = 0;
for (int64_t i = 0; i <= pid; i++) {
int64_t other_seq_len = seq_lens_ptr[i];
int64_t other_pre_len = pre_lens_ptr[i];
int64_t other_extend_len = other_seq_len - other_pre_len;
sum_extend_lens += other_extend_len;
}
int64_t output_start_loc = sum_extend_lens - extend_len;
int64_t num_page_start_loc_self = ceil_div(seq_len, page_size) - ceil_div(pre_len, page_size);
int64_t sum_num_new_pages = 0;
for (int64_t i = 0; i <= pid; i++) {
int64_t other_seq_len = seq_lens_ptr[i];
int64_t other_pre_len = pre_lens_ptr[i];
int64_t other_num_pages_after = ceil_div(other_seq_len, page_size);
int64_t other_num_pages_before = ceil_div(other_pre_len, page_size);
int64_t other_num_new_pages = other_num_pages_after - other_num_pages_before;
sum_num_new_pages += other_num_new_pages;
}
int64_t new_page_start_loc = sum_num_new_pages - num_page_start_loc_self;
int64_t last_loc = last_loc_ptr[pid];
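// Part 1: tokens that still fit in the request's last partially filled page, continuing right after last_loc.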
int64_t num_part1 = safe_min(seq_len, ceil_div(pre_len, page_size) * page_size) - pre_len;
for (int64_t offset = 0; offset < num_part1 && offset < page_size; offset++) {
int64_t output_idx = output_start_loc + offset;
out_indices[output_idx] = last_loc + 1 + offset;
}
if (pre_len + num_part1 == seq_len) {
return;
}
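// Part 2: tokens that fill whole newly allocated pages taken from free_page_ptr.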
int64_t num_part2 = (seq_len / page_size) * page_size - ceil_div(pre_len, page_size) * page_size;
for (int64_t offset = 0; offset < num_part2; offset++) {
int64_t page_idx = new_page_start_loc + offset / page_size;
int64_t page_start = free_page_ptr[page_idx];
int64_t output_idx = output_start_loc + num_part1 + offset;
out_indices[output_idx] = page_start * page_size + offset % page_size;
}
if (pre_len + num_part1 + num_part2 == seq_len) {
return;
}
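// Part 3: remaining tokens in the final, partially filled new page.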
int64_t num_part3 = seq_len - (seq_len / page_size) * page_size;
int64_t last_page_idx = new_page_start_loc + num_page_start_loc_self - 1;
int64_t start_loc = free_page_ptr[last_page_idx];
for (int64_t offset = 0; offset < num_part3 && offset < page_size; offset++) {
int64_t output_idx = output_start_loc + num_part1 + num_part2 + offset;
out_indices[output_idx] = start_loc * page_size + offset;
}
}
__global__ void launch_create_extend_after_decode_spec_info_int32_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int32_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int32_t accept_length = accept_lens_ptr[pid];
int32_t accept_len_cumsum = 0;
for (int32_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int32_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int32_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
__global__ void launch_create_extend_after_decode_spec_info_int64_kernel(
const int32_t* verified_id_ptr,
const int64_t* seq_lens_ptr,
const int64_t* accept_lens_ptr,
int64_t* positions_ptr,
int32_t* new_verified_id_ptr,
int64_t bs) {
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t seq_length = seq_lens_ptr[pid];
int64_t accept_length = accept_lens_ptr[pid];
int64_t accept_len_cumsum = 0;
for (int64_t offset = 0; offset < pid; offset++) {
accept_len_cumsum += accept_lens_ptr[offset];
}
int64_t* positions_ptr1 = positions_ptr + accept_len_cumsum;
for (int64_t offset = 0; offset < accept_length && offset < bs; offset++)
{
positions_ptr1[offset] = seq_length - accept_length + offset;
}
int64_t verified_idx = accept_len_cumsum + accept_length - 1;
new_verified_id_ptr[pid] = verified_id_ptr[verified_idx];
}
void dcu_alloc_decode_kernel(
const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr,
const at::Tensor free_page_ptr,
at::Tensor out_indices,
int64_t bs,
int64_t page_size) {
const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
const int32_t* last_loc_ptr1 = static_cast<const int32_t*>(last_loc_ptr.data_ptr());
const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_alloc_decode_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
const at::Tensor accept_lens,
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs) {
const int32_t* verified_id_ptr;
const int64_t* seq_lens_ptr;
const int32_t* accept_lens_ptr_int32;
const int64_t* accept_lens_ptr_int64;
int64_t* positions_ptr;
int32_t* new_verified_id_ptr;
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
if (accept_lens.dtype() == torch::kInt32)
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int32 = static_cast<const int32_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int32_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int32, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
else
{
verified_id_ptr = static_cast<const int32_t*>(verified_id.data_ptr());
seq_lens_ptr = static_cast<const int64_t*>(seq_lens.data_ptr());
accept_lens_ptr_int64 = static_cast<const int64_t*>(accept_lens.data_ptr());
positions_ptr = static_cast<int64_t*>(positions.data_ptr());
new_verified_id_ptr = static_cast<int32_t*>(new_verified_id.data_ptr());
launch_create_extend_after_decode_spec_info_int64_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(verified_id_ptr, seq_lens_ptr, accept_lens_ptr_int64, positions_ptr, new_verified_id_ptr, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
};
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr,
const at::Tensor free_page_ptr,
at::Tensor out_indices,
int64_t bs,
int64_t page_size) {
const int64_t* pre_lens_ptr1 = static_cast<const int64_t*>(pre_lens_ptr.data_ptr());
const int64_t* seq_lens_ptr1 = static_cast<const int64_t*>(seq_lens_ptr.data_ptr());
const int64_t* last_loc_ptr1 = static_cast<const int64_t*>(last_loc_ptr.data_ptr());
const int64_t* free_page_ptr1 = static_cast<const int64_t*>(free_page_ptr.data_ptr());
int64_t* out_indices1 = static_cast<int64_t*>(out_indices.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_alloc_extend_kernel<<<grid_size, block_size, 0, torch_current_stream>>>(pre_lens_ptr1, seq_lens_ptr1, last_loc_ptr1, free_page_ptr1, out_indices1, bs, page_size);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
__global__ void launch_assign_req_to_token_pool(
const int64_t* req_pool_indices_ptr,
int32_t* req_to_token_ptr,
const int64_t* allocate_lens_ptr,
int64_t* new_allocate_lens,
int64_t* out_cache_loc_ptr,
int64_t shape,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = allocate_lens_ptr[pid];
int64_t kv_end = new_allocate_lens[pid];
int64_t pool_idx = req_pool_indices_ptr[pid];
int32_t* token_pool = (int32_t*)(req_to_token_ptr + pool_idx * shape);
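// The prefix sum of the copy lengths of all preceding requests gives this request's offset into out_cache_loc.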
int64_t sum_out_offset = 0;
for (int length_offset = 0; length_offset < pid; length_offset++) {
int64_t start = allocate_lens_ptr[length_offset];
int64_t end = new_allocate_lens[length_offset];
sum_out_offset += (end - start);
}
int64_t* out_cache_ptr = out_cache_loc_ptr + sum_out_offset;
int64_t copy_length = kv_end - kv_start;
#pragma unroll(32)
for (int out_cache_index = 0; out_cache_index < copy_length; out_cache_index++) {
token_pool[kv_start + out_cache_index] = out_cache_ptr[out_cache_index];
}
}
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs) {
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* allocate_lens_ptr1 = static_cast<const int64_t*>(allocate_lens_ptr.data_ptr());
int64_t* new_allocate_lens1 = static_cast<int64_t*>(new_allocate_lens.data_ptr());
int64_t* out_cache_loc_ptr1 = static_cast<int64_t*>(out_cache_loc_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_assign_req_to_token_pool<<<grid_size, block_size, 0, torch_current_stream>>>(req_pool_indices_ptr1, req_to_token_ptr1, allocate_lens_ptr1, new_allocate_lens1, out_cache_loc_ptr1, shape, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
__global__ void get_last_loc_kernel(
const int32_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices_tensor,
const int64_t* __restrict__ prefix_lens_tensor,
int64_t* __restrict__ result,
int64_t num_tokens,
int64_t req_to_token_stride){
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= num_tokens) return;
int64_t pre_len = prefix_lens_tensor[pid];
if (pre_len > 0) {
int64_t req_idx = req_pool_indices_tensor[pid];
int64_t token_idx = req_idx * req_to_token_stride + (pre_len - 1);
result[pid] = static_cast<int64_t>(req_to_token[token_idx]);
} else {
result[pid] = static_cast<int64_t>(-1);
}
}
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens) {
TORCH_CHECK(req_to_token.device().is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(req_pool_indices.device().is_cuda(), "req_pool_indices must be CUDA tensor");
TORCH_CHECK(prefix_lens.device().is_cuda(), "prefix_lens must be CUDA tensor");
TORCH_CHECK(req_to_token.dim() == 2, "req_to_token must be 2D tensor [batch, seq_len]");
TORCH_CHECK(prefix_lens.dim() == 1, "prefix_lens must be 1D");
TORCH_CHECK(req_pool_indices.dim() == 1, "req_pool_indices must be 1D");
int64_t num_tokens = prefix_lens.numel();
TORCH_CHECK(req_pool_indices.numel() == num_tokens, "req_pool_indices must have same length as prefix_lens");
int64_t req_to_token_stride = req_to_token.stride(0);
auto req_to_token_c = req_to_token.contiguous();
auto req_pool_indices_c = req_pool_indices.contiguous();
auto prefix_lens_c = prefix_lens.contiguous();
const int32_t* req_to_token_ptr = req_to_token_c.data_ptr<int32_t>();
const int64_t* req_pool_indices_ptr = req_pool_indices_c.data_ptr<int64_t>();
const int64_t* prefix_lens_ptr = prefix_lens_c.data_ptr<int64_t>();
auto result = at::empty_like(prefix_lens_c);
int64_t* result_ptr = result.data_ptr<int64_t>();
const int64_t block_size = 64;
const int64_t grid_size = (num_tokens + block_size - 1) / block_size;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
get_last_loc_kernel<<<grid_size, block_size, 0, stream>>>(
req_to_token_ptr,
req_pool_indices_ptr,
prefix_lens_ptr,
result_ptr,
num_tokens,
req_to_token_stride
);
C10_CUDA_KERNEL_LAUNCH_CHECK();
return result;
}
__global__ void launch_assign_extend_cache_locs_kernel(
const int64_t* __restrict__ req_pool_indices, // [bs]
const int32_t* __restrict__ req_to_token, // [max_num_req, pool_len]
const int64_t* __restrict__ start_offset, // [bs]
const int64_t* __restrict__ end_offset, // [bs]
int64_t* __restrict__ out_cache_loc, // [sum(draft_token_num)]
int64_t pool_len,
int64_t bs)
{
int pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t kv_start = start_offset[pid];
int64_t kv_end = end_offset[pid];
int64_t req_id = req_pool_indices[pid];
int64_t out_offset = 0;
for (int i = 0; i < pid; ++i) {
out_offset += end_offset[i] - start_offset[i];
}
const int32_t* src = req_to_token + req_id * pool_len + kv_start;
int64_t* dst = out_cache_loc + out_offset;
for (int64_t i = 0; i < kv_end - kv_start; ++i) {
dst[i] = src[i];
}
}
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs)
{
const int64_t* req_pool_indices_ptr = req_pool_indices.data_ptr<int64_t>();
const int32_t* req_to_token_ptr = req_to_token.data_ptr<int32_t>();
const int64_t* start_offset_ptr = start_offset.data_ptr<int64_t>();
const int64_t* end_offset_ptr = end_offset.data_ptr<int64_t>();
int64_t* out_cache_loc_ptr = out_cache_loc.data_ptr<int64_t>();
constexpr int64_t threads = 128;
int64_t blocks = (bs + threads - 1) / threads;
cudaStream_t stream = at::cuda::getCurrentCUDAStream();
launch_assign_extend_cache_locs_kernel<<<blocks, threads, 0, stream>>>(
req_pool_indices_ptr,
req_to_token_ptr,
start_offset_ptr,
end_offset_ptr,
out_cache_loc_ptr,
pool_len,
bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
template<int PAGED_SIZE>
__global__ void dcu_create_flashmla_kv_indices_kernel(
const int32_t* __restrict__ req_to_token,
const int32_t* __restrict__ req_pool_indices,
const int32_t* __restrict__ page_kernel_lens,
const int32_t* __restrict__ kv_start_idx,
int32_t* __restrict__ kv_indices,
int req_to_token_stride,
int kv_indices_stride)
{
int pid = blockIdx.x; // batch index
int req_pool_index = req_pool_indices[pid];
int kv_start = 0;
int kv_end = 0;
if (kv_start_idx != nullptr) {
kv_start = kv_start_idx[pid];
kv_end = kv_start;
}
kv_end += page_kernel_lens[pid];
int total_len = kv_end - kv_start;
int num_pages = (total_len + PAGED_SIZE - 1) / PAGED_SIZE;
for (int pg = 0; pg < num_pages; ++pg) {
int offset = pg * PAGED_SIZE;
// token id = req_to_token[req_pool_index][kv_start + offset]
int64_t token =
req_to_token[req_pool_index * req_to_token_stride + kv_start + offset];
// page index
kv_indices[pid * kv_indices_stride + pg] = token / PAGED_SIZE;
}
}
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE)
{
TORCH_CHECK(req_to_token.is_cuda(), "req_to_token must be CUDA tensor");
TORCH_CHECK(kv_indices.is_cuda(), "kv_indices must be CUDA tensor");
int bs = req_pool_indices.size(0);
auto stream = at::cuda::getCurrentCUDAStream();
dim3 grid(bs);
dim3 block(1);
const int32_t* kv_start_idx_ptr = nullptr;
if (kv_start_idx.has_value()) {
kv_start_idx_ptr = kv_start_idx.value().data_ptr<int32_t>();
}
if (PAGED_SIZE == 64) {
dcu_create_flashmla_kv_indices_kernel<64><<<grid, block, 0, stream>>>(
req_to_token.data_ptr<int32_t>(),
req_pool_indices.data_ptr<int32_t>(),
page_kernel_lens.data_ptr<int32_t>(),
kv_start_idx_ptr,
kv_indices.data_ptr<int32_t>(),
req_to_token_stride,
kv_indices_stride
);
} else {
TORCH_CHECK(false, "Unsupported PAGED_SIZE");
}
}
__global__ void launch_create_chunked_prefix_cache_kv_indices(
int32_t* req_to_token_ptr,
const int64_t* req_pool_indices_ptr,
const int32_t* chunk_starts_ptr,
const int32_t* chunk_seq_lens_ptr,
const int32_t* chunk_cu_seq_lens_ptr,
int32_t* chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs)
{
int64_t pid = blockIdx.x * blockDim.x + threadIdx.x;
if (pid >= bs) return;
int64_t req_pool_index = req_pool_indices_ptr[pid];
int64_t chunk_kv_indices_offset = chunk_cu_seq_lens_ptr[pid];
int32_t chunk_start_pos = chunk_starts_ptr[pid];
int32_t chunk_seq_len = chunk_seq_lens_ptr[pid];
#pragma unroll(32)
for (int32_t offset = 0; offset < chunk_seq_len; offset++) {
chunk_kv_indices_ptr[chunk_kv_indices_offset + offset] = req_to_token_ptr[req_pool_index * col_num + chunk_start_pos + offset];
}
}
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token_ptr,
const at::Tensor req_pool_indices_ptr,
const at::Tensor chunk_starts_ptr,
const at::Tensor chunk_seq_lens_ptr,
const at::Tensor chunk_cu_seq_lens_ptr,
at::Tensor chunk_kv_indices_ptr,
int64_t col_num,
int64_t bs) {
int32_t* req_to_token_ptr1 = static_cast<int32_t*>(req_to_token_ptr.data_ptr());
const int64_t* req_pool_indices_ptr1 = static_cast<const int64_t*>(req_pool_indices_ptr.data_ptr());
const int32_t* chunk_starts_ptr1 = static_cast<const int32_t*>(chunk_starts_ptr.data_ptr());
const int32_t* chunk_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_seq_lens_ptr.data_ptr());
const int32_t* chunk_cu_seq_lens_ptr1 = static_cast<const int32_t*>(chunk_cu_seq_lens_ptr.data_ptr());
int32_t* chunk_kv_indices_ptr1 = static_cast<int32_t*>(chunk_kv_indices_ptr.data_ptr());
int64_t block_size = 64;
int64_t grid_size = (bs + block_size - 1) / block_size;
cudaStream_t torch_current_stream = at::cuda::getCurrentCUDAStream();
launch_create_chunked_prefix_cache_kv_indices<<<grid_size, block_size, 0, torch_current_stream>>>(req_to_token_ptr1, req_pool_indices_ptr1, chunk_starts_ptr1, chunk_seq_lens_ptr1, chunk_cu_seq_lens_ptr1,chunk_kv_indices_ptr1, col_num, bs);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
......@@ -21,6 +21,7 @@ limitations under the License.
#include "utils.h"
#define WARP_SIZE 64
#define VEC_SIZE 4
using Vec = int4;
......@@ -45,7 +46,7 @@ __device__ __forceinline__ int warp_exclusive_scan(int v, unsigned mask = 0xffff
int original = v;
#pragma unroll
for (int offset = 1; offset < WARP_SIZE; offset <<= 1) {
int n = __shfl_up_sync(mask, v, offset);
int n = __shfl_up(v, offset);
if ((threadIdx.x & (WARP_SIZE - 1)) >= offset) v += n;
}
return v - original;
......
......@@ -60,7 +60,7 @@ template <typename T>
__device__ float convert_to_float(T x) {
if constexpr (std::is_same_v<T, __half>) {
return __half2float(x);
} else if constexpr (std::is_same_v<T, __nv_bfloat16>) {
} else if constexpr (std::is_same_v<T, __hip_bfloat16>) {
return __bfloat162float(x);
} else if constexpr (std::is_same_v<T, float>) {
return x;
......@@ -686,8 +686,8 @@ void topk_softmax(
bias_ptr,
stream);
} else if (dtype == at::ScalarType::BFloat16) {
topkGatingSoftmaxKernelLauncher<__nv_bfloat16>(
reinterpret_cast<const __nv_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
topkGatingSoftmaxKernelLauncher<__hip_bfloat16>(
reinterpret_cast<const __hip_bfloat16*>(gating_output.data_ptr<at::BFloat16>()),
topk_weights.data_ptr<float>(),
topk_indices.data_ptr<int>(),
softmax_workspace.data_ptr<float>(),
......
......@@ -3,7 +3,7 @@
// copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h
#define QK_K 256
#define K_QUANTS_PER_ITERATION 2
#define WARP_SIZE_GGUF 32
#define WARP_SIZE_GGUF 64
#define K_SCALE_SIZE 12
#define CUDA_DEQUANTIZE_BLOCK_SIZE 256
#define CUDA_QUANTIZE_BLOCK_SIZE 256
......
......@@ -515,6 +515,75 @@ void segment_packbits(
/*
* From csrc/kvcacheio
*/
void dcu_create_extend_after_decode_spec_info(
const at::Tensor verified_id,
const at::Tensor seq_lens,
const at::Tensor accept_lens,
at::Tensor positions,
at::Tensor new_verified_id,
int64_t bs);
void dcu_create_chunked_prefix_cache_kv_indices(
at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor chunk_starts,
const at::Tensor chunk_seq_lens,
const at::Tensor chunk_cu_seq_lens,
at::Tensor chunk_kv_indices,
int64_t col_num,
int64_t bs);
void dcu_create_flashmla_kv_indices(
const at::Tensor& req_to_token,
const at::Tensor& req_pool_indices,
const at::Tensor& page_kernel_lens,
const c10::optional<at::Tensor>& kv_start_idx,
at::Tensor& kv_indices,
int64_t req_to_token_stride,
int64_t kv_indices_stride,
int64_t PAGED_SIZE);
void dcu_assign_extend_cache_locs(
const at::Tensor req_pool_indices,
const at::Tensor req_to_token,
const at::Tensor start_offset,
const at::Tensor end_offset,
at::Tensor out_cache_loc,
int64_t pool_len,
int64_t bs);
at::Tensor dcu_get_last_loc(
const at::Tensor req_to_token,
const at::Tensor req_pool_indices,
const at::Tensor prefix_lens);
void dcu_assign_req_to_token_pool(
const at::Tensor req_pool_indices_ptr,
at::Tensor req_to_token_ptr,
const at::Tensor allocate_lens_ptr,
at::Tensor new_allocate_lens,
at::Tensor out_cache_loc_ptr,
int64_t shape,
int64_t bs);
void dcu_alloc_extend_kernel(
const at::Tensor pre_lens_ptr,
const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr,
const at::Tensor free_page_ptr,
at::Tensor out_indices,
int64_t bs,
int64_t page_size);
void dcu_alloc_decode_kernel(
const at::Tensor seq_lens_ptr,
const at::Tensor last_loc_ptr,
const at::Tensor free_page_ptr,
at::Tensor out_indices,
int64_t bs,
int64_t page_size);
void transfer_kv_per_layer(
const at::Tensor src_k,
at::Tensor dst_k,
......
......@@ -340,7 +340,7 @@ inline bool getEnvEnablePDL() {
#define CEILDIV(x, y) (((x) + (y) - 1) / (y))
#ifndef USE_ROCM
#define WARP_SIZE 32
#define WARP_SIZE 64
#else
#if defined(__GFX9__) || !defined(__HIP_DEVICE_COMPILE__)
#define WARP_SIZE 64
......@@ -369,25 +369,25 @@ __device__ __forceinline__ dstDtype castFromFloat(float val) {
#endif
// add FP8 support
#ifndef USE_ROCM
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else // USE_ROCM
#if HIP_FP8_TYPE_FNUZ
#include <c10/util/Float8_e4m3fnuz.h>
using FP8_TYPE = c10::Float8_e4m3fnuz;
constexpr auto FP8_E4M3_MAX = 224.0f;
#else
#if HIP_FP8_TYPE_E4M3
#include <c10/util/Float8_e4m3fn.h>
using FP8_TYPE = c10::Float8_e4m3fn;
C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
#else
#error "fp8 is not supported in this processor (arch < gfx942)."
#endif // HIP_FP8_TYPE_E4M3
#endif // HIP_FP8_TYPE_FNUZ
#endif // USE_ROCM
// #ifndef USE_ROCM
// #include <c10/util/Float8_e4m3fn.h>
// using FP8_TYPE = c10::Float8_e4m3fn;
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
// #else // USE_ROCM
// #if HIP_FP8_TYPE_FNUZ
// #include <c10/util/Float8_e4m3fnuz.h>
// using FP8_TYPE = c10::Float8_e4m3fnuz;
// constexpr auto FP8_E4M3_MAX = 224.0f;
// #else
// #if HIP_FP8_TYPE_E4M3
// #include <c10/util/Float8_e4m3fn.h>
// using FP8_TYPE = c10::Float8_e4m3fn;
// C10_HOST_DEVICE constexpr auto FP8_E4M3_MAX = std::numeric_limits<FP8_TYPE>::max();
// #else
// #error "fp8 is not supported in this processor (arch < gfx942)."
// #endif // HIP_FP8_TYPE_E4M3
// #endif // HIP_FP8_TYPE_FNUZ
// #endif // USE_ROCM
#define FULL_MASK 0xffffffff
......
......@@ -13,6 +13,26 @@ _IMPORT_ERROR = ImportError(
"Failed to load sgl_kernel.flashmla_ops extension. Ensure CUDA Driver >= 12.4"
)
def dcu_create_flashmla_kv_indices(
    req_to_token_ptr,
    req_pool_indices_ptr,
    page_kernel_lens_ptr,
    kv_start_idx,
    kv_indices_ptr,
    req_to_token_ptr_stride,
    kv_indices_ptr_stride,
    PAGED_SIZE=64,
):
    torch.ops.sgl_kernel.dcu_create_flashmla_kv_indices(
        req_to_token_ptr,
        req_pool_indices_ptr,
        page_kernel_lens_ptr,
        kv_start_idx,
        kv_indices_ptr,
        req_to_token_ptr_stride,
        kv_indices_ptr_stride,
        PAGED_SIZE,
    )
def get_mla_metadata(
cache_seqlens: torch.Tensor,
......