Unverified commit 42c87045, authored by fzyzcjy and committed by GitHub

Add PDL support for quant kernel and rope kernel (#9106)

parent c9bf3877
@@ -635,6 +635,8 @@ def _set_envs_and_config(server_args: ServerArgs):
     os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
+    # flashinfer uses this environment variable for various kernels from MoE to quant kernels
+    os.environ["TRTLLM_ENABLE_PDL"] = "1"
     # Set prometheus env vars
     if server_args.enable_metrics:
...
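Note on the hunk above: `TRTLLM_ENABLE_PDL` is a process-wide environment switch that, per the in-diff comment, flashinfer's TRT-LLM-derived kernels (MoE, quant, and now rope) consult when deciding whether to launch with PDL. Together with the removal in the next hunk, this commit sets the switch unconditionally at server startup instead of only on the Flashinfer MOE path. A hedged sketch of how a launcher might read such a switch; this is illustrative only, not flashinfer's actual code, and only the variable name and the "1" convention come from the diff:

```cpp
#include <cstdio>
#include <cstdlib>
#include <cstring>

// Hedged sketch (not flashinfer's real logic): turn the process-wide
// TRTLLM_ENABLE_PDL switch into a per-launch boolean.
static bool pdl_enabled_from_env() {
  const char* v = std::getenv("TRTLLM_ENABLE_PDL");
  // Treat "1" as enabled; unset or anything else as disabled.
  return v != nullptr && std::strcmp(v, "1") == 0;
}

int main() {
  std::printf("PDL requested via env: %s\n", pdl_enabled_from_env() ? "yes" : "no");
  return 0;
}
```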
@@ -550,7 +550,6 @@ class ServerArgs:
         assert (
             self.quantization == "modelopt_fp4"
         ), "modelopt_fp4 quantization is required for Flashinfer MOE"
-        os.environ["TRTLLM_ENABLE_PDL"] = "1"
         assert self.ep_size in [
             1,
             self.tp_size,
...
@@ -90,7 +90,7 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
   m.def(
       "apply_rope_pos_ids_cos_sin_cache(Tensor q, Tensor k, Tensor! q_rope, Tensor! k_rope, Tensor cos_sin_cache, "
-      "Tensor pos_ids, bool interleave, int cuda_stream, "
+      "Tensor pos_ids, bool interleave, bool enable_pdl, int cuda_stream, "
       "Tensor? v, Tensor!? k_buffer, Tensor!? v_buffer, Tensor? kv_cache_loc) -> ()");
   m.impl("apply_rope_pos_ids_cos_sin_cache", torch::kCUDA, &apply_rope_pos_ids_cos_sin_cache);
...
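The schema string registered here must agree argument-for-argument with the C++ function bound via `m.impl`, which is why `bool enable_pdl` is added both to this schema and to the `apply_rope_pos_ids_cos_sin_cache` declarations later in the commit. A hedged, toy-sized sketch of the same def/impl pattern; the op name `toy_scale_` is hypothetical and not part of sgl-kernel:

```cpp
#include <ATen/ATen.h>
#include <torch/library.h>

// Hypothetical op, used only to illustrate the def/impl pairing.
// In the schema language, "float" maps to C++ double and "Tensor!"
// marks an argument the op mutates in place.
void toy_scale_(at::Tensor x, double alpha, bool enable_pdl) {
  (void)enable_pdl;  // a real kernel would forward this flag to its CUDA launcher
  x.mul_(alpha);
}

TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
  m.def("toy_scale_(Tensor! x, float alpha, bool enable_pdl) -> ()");
  m.impl("toy_scale_", torch::kCUDA, &toy_scale_);
}
```

Ops registered through `TORCH_LIBRARY_FRAGMENT(sgl_kernel, m)` surface as `torch.ops.sgl_kernel.<name>` on the Python side.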
@@ -104,6 +104,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
   uint32_t by = blockIdx.y;
   const uint32_t bdy = blockDim.y;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
   vec_t<float, vec_size> cos, sin;
   if (bx * bdy + ty < nnz) {
     const uint32_t idx = bx * bdy + ty;
@@ -178,6 +182,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedHeadParallelismKernel
       }
     }
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 template <
@@ -220,6 +228,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
   uint32_t bx = blockIdx.x, tx = threadIdx.x, ty = threadIdx.y;
   const uint32_t bdy = blockDim.y;
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");
+#endif
   vec_t<float, vec_size> cos, sin;
   if (bx * bdy + ty < nnz) {
     const uint32_t idx = bx * bdy + ty;
@@ -296,6 +308,10 @@ __global__ void BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel(
       }
     }
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 #define DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, ...) \
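The `griddepcontrol` PTX added above is the device half of Programmatic Dependent Launch (PDL), available on Hopper (compute capability 9.0) and newer, hence the `__CUDA_ARCH__ >= 900` guards. `griddepcontrol.wait` at the top of the kernel blocks until the preceding grid has signalled (or fully finished), so the kernel can be launched while its predecessor is still draining without reading stale data; `griddepcontrol.launch_dependents` at the bottom signals that the next PDL-launched grid may begin. CUDA exposes the same pair as built-ins; a hedged sketch with toy producer/consumer kernels (illustrative names, not the rope kernels above):

```cuda
#include <cuda_runtime.h>

// Hedged sketch of the device-side PDL handshake (sm_90+ only).
// The rope kernels above use the raw PTX behind a __CUDA_ARCH__ >= 900
// guard; the CUDA built-ins below lower to the same instructions.
__global__ void producer(float* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = 2.0f * i;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Equivalent of "griddepcontrol.launch_dependents;": the work that
  // dependents care about is done, so a PDL-launched successor may start.
  cudaTriggerProgrammaticLaunchCompletion();
#endif
}

__global__ void consumer(const float* in, float* out, int n) {
  // Prologue work that does not read the producer's output can overlap
  // with the producer's tail when this kernel is launched with PDL.
  int i = blockIdx.x * blockDim.x + threadIdx.x;
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  // Equivalent of "griddepcontrol.wait;": do not read the producer's
  // output until it has signalled completion (or finished).
  cudaGridDependencySynchronize();
#endif
  if (i < n) out[i] = in[i] + 1.0f;
}
```

The rope kernels do both: they wait at the top (so they may safely overlap the previous kernel's tail) and trigger launch_dependents at the bottom (so the next PDL-launched kernel may overlap theirs). The host half of the handshake is what the new `LAUNCH_KERNEL_RAW` macro in the next file sets up; a standalone sketch of that side follows the launcher hunk below.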
@@ -340,12 +356,59 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
     IdType* kv_cache_loc,
     bool interleave,
     bool save_kv_cache,
+    bool enable_pdl,
     cudaStream_t stream = nullptr) {
   int dev_id = 0;
   int num_sms = 0;
   FLASHINFER_CUDA_CALL(cudaGetDevice(&dev_id));
   FLASHINFER_CUDA_CALL(cudaDeviceGetAttribute(&num_sms, cudaDevAttrMultiProcessorCount, dev_id));
+
+#define LAUNCH_KERNEL_RAW(kernel_name)                                  \
+  do {                                                                  \
+    cudaLaunchConfig_t config = {};                                     \
+    config.gridDim = nblks;                                             \
+    config.blockDim = nthrs;                                            \
+    config.dynamicSmemBytes = 0;                                        \
+    config.stream = stream;                                             \
+    cudaLaunchAttribute attrs[1] = {};                                  \
+    attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;   \
+    attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl;   \
+    config.numAttrs = 1;                                                \
+    config.attrs = attrs;                                               \
+                                                                        \
+    FLASHINFER_CUDA_CALL(cudaLaunchKernelEx(                            \
+        &config,                                                        \
+        kernel_name,                                                    \
+        q,                                                              \
+        k,                                                              \
+        v,                                                              \
+        q_rope,                                                         \
+        k_rope,                                                         \
+        k_buffer,                                                       \
+        v_buffer,                                                       \
+        cos_sin_cache,                                                  \
+        pos_ids,                                                        \
+        nnz,                                                            \
+        num_qo_heads,                                                   \
+        num_kv_heads,                                                   \
+        rotary_dim,                                                     \
+        q_stride_n,                                                     \
+        q_stride_h,                                                     \
+        k_stride_n,                                                     \
+        k_stride_h,                                                     \
+        v_stride_n,                                                     \
+        v_stride_h,                                                     \
+        q_rope_stride_n,                                                \
+        q_rope_stride_h,                                                \
+        k_rope_stride_n,                                                \
+        k_rope_stride_h,                                                \
+        k_buffer_stride_n,                                              \
+        k_buffer_stride_h,                                              \
+        v_buffer_stride_n,                                              \
+        v_buffer_stride_h,                                              \
+        kv_cache_loc));                                                 \
+  } while (0)
   DISPATCH_SAVE_KV_CACHE(save_kv_cache, SAVE_KV_CACHE, {
     DISPATCH_INTERLEAVE(interleave, INTERLEAVE, {
       DISPATCH_HEAD_DIM(head_dim, HEAD_DIM, {
@@ -359,35 +422,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
         uint32_t bdy = num_threads / bdx;
         // how many blocks needed to process all tokens
         uint32_t nblks_x = (nnz + bdy - 1) / bdy;
-        void* args[] = {
-            (void*)&q,
-            (void*)&k,
-            (void*)&v,
-            (void*)&q_rope,
-            (void*)&k_rope,
-            (void*)&k_buffer,
-            (void*)&v_buffer,
-            (void*)&cos_sin_cache,
-            (void*)&pos_ids,
-            (void*)&nnz,
-            (void*)&num_qo_heads,
-            (void*)&num_kv_heads,
-            (void*)&rotary_dim,
-            (void*)&q_stride_n,
-            (void*)&q_stride_h,
-            (void*)&k_stride_n,
-            (void*)&k_stride_h,
-            (void*)&v_stride_n,
-            (void*)&v_stride_h,
-            (void*)&q_rope_stride_n,
-            (void*)&q_rope_stride_h,
-            (void*)&k_rope_stride_n,
-            (void*)&k_rope_stride_h,
-            (void*)&k_buffer_stride_n,
-            (void*)&k_buffer_stride_h,
-            (void*)&v_buffer_stride_n,
-            (void*)&v_buffer_stride_h,
-            (void*)&kv_cache_loc};
         auto kernel_0 = BatchQKApplyRotaryPosIdsCosSinCacheEnhancedKernel<
             SAVE_KV_CACHE,
             INTERLEAVE,
@@ -405,7 +440,7 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
         if ((nnz + bdy - 1) / bdy >= num_ctas_0) {
           dim3 nblks(nblks_x);
           dim3 nthrs(bdx, bdy);
-          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_0, nblks, nthrs, args, 0, stream));
+          LAUNCH_KERNEL_RAW(kernel_0);
         } else {
           dim3 nblks(nblks_x, num_qo_heads + num_kv_heads);
           dim3 nthrs(bdx, bdy);
@@ -417,11 +452,12 @@ cudaError_t BatchQKApplyRotaryPosIdsCosSinCacheEnhanced(
               bdx,
               DType,
               IdType>;
-          FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel_1, nblks, nthrs, args, 0, stream));
+          LAUNCH_KERNEL_RAW(kernel_1);
         }
       });
     });
   });
+#undef LAUNCH_KERNEL_RAW
   return cudaSuccess;
 }
...
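`LAUNCH_KERNEL_RAW` replaces the old `cudaLaunchKernel(..., args, 0, stream)` call (and the now-deleted `void* args[]` array) with `cudaLaunchKernelEx`, since the programmatic-stream-serialization attribute, the host half of PDL, can only be attached through a `cudaLaunchConfig_t`. A minimal, self-contained hedged sketch of the same launch pattern with toy kernels (names are illustrative, not sgl-kernel code):

```cuda
#include <cstdio>
#include <cuda_runtime.h>

__global__ void fill_kernel(float* buf, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] = 1.0f;
  // Without an explicit cudaTriggerProgrammaticLaunchCompletion(), the
  // "dependents may start" signal fires implicitly as blocks finish.
}

__global__ void add_one_kernel(float* buf, int n) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
  cudaGridDependencySynchronize();  // do not read buf before fill_kernel is done
#endif
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) buf[i] += 1.0f;
}

// Mirrors the structure of LAUNCH_KERNEL_RAW: request PDL via a launch attribute.
template <typename... Args>
cudaError_t launch_with_pdl(void (*kernel)(Args...), dim3 grid, dim3 block,
                            cudaStream_t stream, bool enable_pdl, Args... args) {
  cudaLaunchConfig_t config = {};
  config.gridDim = grid;
  config.blockDim = block;
  config.dynamicSmemBytes = 0;
  config.stream = stream;

  cudaLaunchAttribute attrs[1] = {};
  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
  attrs[0].val.programmaticStreamSerializationAllowed = enable_pdl ? 1 : 0;
  config.attrs = attrs;
  config.numAttrs = 1;

  return cudaLaunchKernelEx(&config, kernel, args...);
}

int main() {
  const int n = 1 << 20;
  float* buf = nullptr;
  cudaMalloc(&buf, n * sizeof(float));

  dim3 block(256), grid((n + block.x - 1) / block.x);
  cudaStream_t stream = nullptr;  // default stream, for brevity

  // PDL is requested on the *second* launch; that is what lets it begin while
  // fill_kernel is still draining (its device-side wait keeps the read safe).
  launch_with_pdl(fill_kernel, grid, block, stream, /*enable_pdl=*/false, buf, n);
  launch_with_pdl(add_one_kernel, grid, block, stream, /*enable_pdl=*/true, buf, n);

  cudaDeviceSynchronize();
  std::printf("last CUDA error: %s\n", cudaGetErrorString(cudaGetLastError()));
  cudaFree(buf);
  return 0;
}
```

When `enable_pdl` is false the attribute is simply set to 0 and the launch behaves like an ordinary stream-ordered launch, which is why the same macro can serve both paths.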
@@ -27,6 +27,7 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
+    bool enable_pdl,
     int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
@@ -124,12 +125,14 @@ void apply_rope_pos_ids_cos_sin_cache(
         kv_cache_loc_ptr,
         interleave,
         save_kv_cache,
+        enable_pdl,
         stream);
     TORCH_CHECK(
         status == cudaSuccess,
         "BatchQKApplyRotaryPosIdsCosSinCacheEnhanced failed with error code " +
             std::string(cudaGetErrorString(status)));
   } else {
+    TORCH_CHECK(!enable_pdl);
     cudaError_t status = BatchQKApplyRotaryPosIdsCosSinCache(
         static_cast<c_type*>(q.data_ptr()),
         static_cast<c_type*>(k.data_ptr()),
...
@@ -151,6 +151,7 @@ void apply_rope_pos_ids_cos_sin_cache(
     at::Tensor cos_sin_cache,
     at::Tensor pos_ids,
     bool interleave,
+    bool enable_pdl,
     int64_t cuda_stream,
     const std::optional<at::Tensor>& v,
     const std::optional<at::Tensor>& k_buffer,
...
@@ -271,6 +271,7 @@ def apply_rope_with_cos_sin_cache_inplace(
     cos_sin_cache: torch.Tensor,
     is_neox: bool = True,
     fused_set_kv_buffer_arg: Optional[FusedSetKVBufferArg] = None,
+    enable_pdl: Optional[bool] = None,
 ) -> None:
     r"""
     Apply rotary embedding to keys and queries with precomputed cos/sin values.
@@ -307,6 +308,10 @@ def apply_rope_with_cos_sin_cache_inplace(
     if cos_sin_cache.dtype != torch.float32:
         raise ValueError("cos_sin_cache should be float32")
+    if enable_pdl is None:
+        # the non-fused branch does not yet support PDL, but after we switch to our impl for that branch it will
+        enable_pdl = is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)
     if (a := fused_set_kv_buffer_arg) is not None:
         assert a.k_scale is None, "k_scale is not yet supported"
         assert a.v_scale is None, "v_scale is not yet supported"
@@ -323,6 +328,7 @@ def apply_rope_with_cos_sin_cache_inplace(
         cos_sin_cache,
         positions.long(),
         (not is_neox),
+        enable_pdl,
         get_cuda_stream(),
         (
             _view_3d(fused_set_kv_buffer_arg.value)
...
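On the Python side, `enable_pdl` defaults to `is_arch_support_pdl() and (fused_set_kv_buffer_arg is not None)`: PDL is requested only on hardware that supports it and only on the fused save-KV path, since the non-fused path falls through to the older kernel, which the C++ binding above rejects via `TORCH_CHECK(!enable_pdl)`. `is_arch_support_pdl()` is defined outside this diff; presumably it reduces to "compute capability 9.0 or newer", matching the `__CUDA_ARCH__ >= 900` guards in the kernels. A hedged C++ sketch of an equivalent host-side check (an assumption, not sglang's actual implementation):

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Hedged sketch: PDL (griddepcontrol / programmatic stream serialization)
// requires Hopper-class hardware, i.e. compute capability >= 9.0.
bool arch_supports_pdl(int device = 0) {
  int major = 0;
  if (cudaDeviceGetAttribute(&major, cudaDevAttrComputeCapabilityMajor, device) != cudaSuccess) {
    return false;  // be conservative if the query fails
  }
  return major >= 9;
}

int main() {
  std::printf("PDL supported on device 0: %s\n", arch_supports_pdl() ? "yes" : "no");
  return 0;
}
```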