Unverified commit 42c87045 authored by fzyzcjy, committed by GitHub

Add PDL support for quant kernel and rope kernel (#9106)

parent c9bf3877
@@ -635,6 +635,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+    # flashinfer uses this environment variable for various kernels from MoE to quant kernels
+    os.environ["TRTLLM_ENABLE_PDL"] = "1"
     # Set prometheus env vars
     if server_args.enable_metrics:
......
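
For reference, PDL (Programmatic Dependent Launch) on Hopper-class GPUs (sm_90 and newer) lets a dependent kernel begin launching while the previous kernel on the same stream is still draining, and TRTLLM_ENABLE_PDL is the environment variable that the TensorRT-LLM-derived FlashInfer kernels consult to opt in. Below is a minimal sketch of such a gate; the helper name is hypothetical, not FlashInfer's actual code:

#include <cstdio>
#include <cstdlib>
#include <cstring>

// Hypothetical helper: treat TRTLLM_ENABLE_PDL as a boolean flag, matching
// the "1" value that _set_envs_and_config exports above.
static bool pdl_enabled_from_env() {
  const char* v = std::getenv("TRTLLM_ENABLE_PDL");
  return v != nullptr && std::strcmp(v, "1") == 0;
}

int main() {
  std::printf("PDL %s\n", pdl_enabled_from_env() ? "enabled" : "disabled");
  return 0;
}
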
@@ -550,7 +550,6 @@ class ServerArgs:
             assert (
                 self.quantization == "modelopt_fp4"
             ), "modelopt_fp4 quantization is required for Flashinfer MOE"
-            os.environ["TRTLLM_ENABLE_PDL"] = "1"
             assert self.ep_size in [
                 1,
                 self.tp_size,
......
@@ -90,7 +90,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream, "
+      "Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
       "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
......
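
The schema string and the C++ function must agree on the new bool enable_pdl parameter in both position and type. A minimal sketch of this registration pattern, with illustrative names rather than the actual sgl_kernel sources:

#include <torch/library.h>

// Toy op: the schema's `bool enable_pdl` maps to a C++ bool, and schema `int`
// maps to int64_t, mirroring the apply_rope_pos_ids_cos_sin_cache change above.
void toy_rope_op(at::Tensor q, bool enable_pdl, int64_t cuda_stream) {
  // A real implementation would forward enable_pdl into the CUDA launch config.
}

TORCH_LIBRARY_FRAGMENT(toy_ns, m) {
  // Parameter order in the schema string must match the C++ signature.
  m.def("toy_rope_op(Tensor q, bool enable_pdl, int cuda_stream) -> ()");
  m.impl("toy_rope_op", torch::kCUDA, &toy_rope_op);
}
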
@@ -104,6 +104,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
   uint32_t by = blockIdx.y;
   const uint32_t bdy = blockDim.y;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
   vec_t<float, vec_size> cos, sin;
   if (bx * bdy + ty < nnz) {
     const uint32_t idx = bx * bdy + ty;
@@ -178,6 +182,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
       }
     }
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }

 template <
@@ -220,6 +228,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
   uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
   const uint32_t bdy = blockDim.y;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
   vec_t<float, vec_size> cos, sin;
   if (bx * bdy + ty < nnz) {
     const uint32_t idx = bx * bdy + ty;
@@ -296,6 +308,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
       }
     }
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }

 #define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
@@ -340,12 +356,59 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
     IdType* kv_cache_loc,
     bool interleave,
     bool save_kv_cache,
+    bool enable_pdl,
     cudaStream_t stream = nullptr) {
   int dev_id = 0;
   int num_sms = 0;
   FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
+#define LAUNCH_KERNEL_RAW(kernel_name)                                  \
+  do {                                                                  \
+    cudaLaunchConfig_t config = {};                                     \
+    config.gridDim = nblks;                                             \
+    config.blockDim = nthrs;                                            \
+    config.dynamicSmemBytes = 0;                                        \
+    config.stream = stream;                                             \
+    cudaLaunchAttribute attrs[1] = {};                                  \
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;   \
+    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;   \
+    config.numAttrs = 1;                                                \
+    config.attrs = attrs;                                               \
+                                                                        \
+    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(                            \
+        &config,                                                        \
+        kernel_name,                                                    \
+        q,                                                              \
+        k,                                                              \
+        v,                                                              \
+        q_rope,                                                         \
+        k_rope,                                                         \
+        k_buffer,                                                       \
+        v_buffer,                                                       \
+        cos_sin_cache,                                                  \
+        pos_ids,                                                        \
+        nnz,                                                            \
+        num_qo_heads,                                                   \
+        num_kv_heads,                                                   \
+        rotary_dim,                                                     \
+        q_stride_n,                                                     \
+        q_stride_h,                                                     \
+        k_stride_n,                                                     \
+        k_stride_h,                                                     \
+        v_stride_n,                                                     \
+        v_stride_h,                                                     \
+        q_rope_stride_n,                                                \
+        q_rope_stride_h,                                                \
+        k_rope_stride_n,                                                \
+        k_rope_stride_h,                                                \
+        k_buffer_stride_n,                                              \
+        k_buffer_stride_h,                                              \
+        v_buffer_stride_n,                                              \
+        v_buffer_stride_h,                                              \
+        kv_cache_loc));                                                 \
+  } while (0)

   DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
     DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
       DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
@@ -359,35 +422,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
         uint32_t bdy = num_threads / bdx;
         // how many blocks needed to process all tokens
         uint32_t nblks_x = (nnz + bdy - 1) / bdy;
-        void* args[] = {
-            (void*)&q,
-            (void*)&k,
-            (void*)&v,
-            (void*)&q_rope,
-            (void*)&k_rope,
-            (void*)&k_buffer,
-            (void*)&v_buffer,
-            (void*)&cos_sin_cache,
-            (void*)&pos_ids,
-            (void*)&nnz,
-            (void*)&num_qo_heads,
-            (void*)&num_kv_heads,
-            (void*)&rotary_dim,
-            (void*)&q_stride_n,
-            (void*)&q_stride_h,
-            (void*)&k_stride_n,
-            (void*)&k_stride_h,
-            (void*)&v_stride_n,
-            (void*)&v_stride_h,
-            (void*)&q_rope_stride_n,
-            (void*)&q_rope_stride_h,
-            (void*)&k_rope_stride_n,
-            (void*)&k_rope_stride_h,
-            (void*)&k_buffer_stride_n,
-            (void*)&k_buffer_stride_h,
-            (void*)&v_buffer_stride_n,
-            (void*)&v_buffer_stride_h,
-            (void*)&kv_cache_loc};
         auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
             SAVE_KV_CACHE,
             INTERLEAVE,
@@ -405,7 +440,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
         if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
           dim3 nblks(nblks_x);
           dim3 nthrs(bdx, bdy);
-          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
+          LAUNCH_KERNEL_RAW(kernel_0);
         } else {
           dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
           dim3 nthrs(bdx, bdy);
@@ -417,11 +452,12 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
               bdx,
               DType,
               IdType>;
-          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
+          LAUNCH_KERNEL_RAW(kernel_1);
         }
       });
     });
   });
+#undef LAUNCH_KERNEL_RAW
   return cudaSuccess;
 }
......
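
The two halves of this change fit together as follows: the host side launches with cudaLaunchKernelEx and the programmatic-stream-serialization attribute (the LAUNCH_KERNEL_RAW macro above), while the device side brackets its work with griddepcontrol.wait and griddepcontrol.launch_dependents. Below is a self-contained sketch using the CUDA C++ intrinsics that compile to those PTX instructions (requires CUDA 11.8+ and `nvcc -arch=sm_90`); the producer/consumer kernels are illustrative stand-ins, not the rope kernels:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void producer(float* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 1.0f;
  // Compiles to griddepcontrol.launch_dependents: the dependent grid may
  // begin launching now; its data reads are still fenced by the wait below.
  cudaTriggerProgrammaticLaunchCompletion();
}

__global__ void consumer(float* buf, int n) {
  // Compiles to griddepcontrol.wait: block until the producer's memory
  // writes are visible, so only the prologue overlaps, never the data access.
  cudaGridDependencySynchronize();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] += 1.0f;
}

int main() {
  const int n = 1 << 20;
  float* buf = nullptr;
  cudaMalloc(&buf, n * sizeof(float));
  cudaStream_t stream;
  cudaStreamCreate(&stream);

  dim3 nthrs(256), nblks((n + 255) / 256);
  producer<<<nblks, nthrs, 0, stream>>>(buf, n);

  // Same attribute LAUNCH_KERNEL_RAW sets when enable_pdl is true.
  cudaLaunchConfig_t config = {};
  config.gridDim = nblks;
  config.blockDim = nthrs;
  config.stream = stream;
  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = 1;
  config.attrs = attrs;
  config.numAttrs = 1;
  cudaLaunchKernelEx(&config, consumer, buf, n);

  cudaStreamSynchronize(stream);
  float out = 0.0f;
  cudaMemcpy(&out, buf, sizeof(float), cudaMemcpyDeviceToHost);
  std::printf("buf[0] = %f (expect 2.0)\n", out);
  cudaFree(buf);
  return 0;
}
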
@@ -27,6 +27,7 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
+    bool enable_pdl,
     int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
@@ -124,12 +125,14 @@ void apply_rope_pos_ids_cos_sin_cache(
         kv_cache_loc_ptr,
         interleave,
         save_kv_cache,
+        enable_pdl,
         stream);
     TORCH_CHECK(
         status == cudaSuccess,
         "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
             std::string(cudaGetErrorString(status)));
   } else {
+    TORCH_CHECK(!enable_pdl);
     cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
         static_cast<c_type*>(q.data_ptr()),
         static_cast<c_type*>(k.data_ptr()),
......
@@ -151,6 +151,7 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
+    bool enable_pdl,
     int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
......
@@ -271,6 +271,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     cos_sin_cache: torch.Tensor,
     is_neox: bool = True,
     fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+    enable_pdl: Optional[bool] = None,
 ) -> None:
     r"""
     Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -307,6 +308,10 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
+    if enable_pdl is None:
+        # the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
+        enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
+
     if (a := fused_set_kv_buffer_arg) is not None:
         assert a.k_scale is None, "k_scale is not yet supported"
         assert a.v_scale is None, "v_scale is not yet supported"
@@ -323,6 +328,7 @@ def apply_rope_with_cos_sin_cache_inplace(
         cos_sin_cache,
         positions.long(),
         (not is_neox),
+        enable_pdl,
         get_cuda_stream(),
         (
             _view_3d(fused_set_kv_buffer_arg.value)
......