Commit 9e03fc67 authored by raojy's avatar raojy
Browse files

nemotron_enable

parent 3b50924c
Pipeline #3457 failed with stages
in 0 seconds
......@@ -80,7 +80,7 @@ class CudaRTLibrary:
),
]
# https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Runtime_API_functions_supported_by_HIP.html # noqa
# https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/hip/hip_runtime_api_functions_supported_by_HIP.html # noqa
cuda_to_hip_mapping = {
"cudaSetDevice": "hipSetDevice",
"cudaDeviceSynchronize": "hipDeviceSynchronize",
......
......@@ -394,6 +394,7 @@ def _chunk_scan_fwd(
if initial_states is not None
else (0, 0, 0, 0)
)
initstates_ptr = initial_states if initial_states is not None else states
_chunk_scan_fwd_kernel[grid](
cb_ptr=cb,
......@@ -406,7 +407,7 @@ def _chunk_scan_fwd(
C_ptr=C,
states_ptr=states,
D_ptr=D,
initstates_ptr=initial_states,
initstates_ptr=initstates_ptr,
cu_chunk_seqlens_ptr=cu_chunk_seqlens,
chunk_size=chunk_size,
hdim=headdim,
......
......@@ -657,6 +657,8 @@ def chunk_state_varlen(
batch,
nheads,
)
initstates_ptr = initial_states if initial_states is not None else chunk_states
with torch.cuda.device(x.device.index):
_chunk_state_varlen_kernel[grid](
x_ptr=x,
......@@ -666,7 +668,7 @@ def chunk_state_varlen(
chunk_states_ptr=chunk_states,
cu_seqlens_ptr=cu_seqlens,
states_ptr=states,
initstates_ptr=initial_states,
initstates_ptr=initstates_ptr,
hdim=headdim,
dstate=dstate,
chunk_size=chunk_size,
......
......@@ -248,10 +248,11 @@ def apply_top_k_top_p(
if p is None and k is None:
return logits
if HAS_TRITON and logits.shape[0] >= 8:
return apply_top_k_top_p_triton(logits, k, p)
    # Forcibly commented out so execution never takes the Triton code path
# if HAS_TRITON and logits.shape[0] >= 8:
# return apply_top_k_top_p_triton(logits, k, p)
# Use pytorch sort implementation for small batch sizes.
    # Always fall through to the safe native PyTorch implementation, regardless of batch size
return apply_top_k_top_p_pytorch(logits, k, p)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment