Unverified Commit cf8a613a authored by Xin Yang's avatar Xin Yang Committed by GitHub
Browse files

Support only half types for concat_mla_q kernel (#37892)


Signed-off-by: default avatarXin Yang <xyangx@amazon.com>
parent 01acf96c
......@@ -1490,6 +1490,9 @@ void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
TORCH_CHECK(ql_nope.stride(2) == 1, "ql_nope must have stride 1 in dim 2");
TORCH_CHECK(q_pe.stride(2) == 1, "q_pe must have stride 1 in dim 2");
TORCH_CHECK(q_out.stride(2) == 1, "q_out must have stride 1 in dim 2");
TORCH_CHECK(ql_nope.scalar_type() == at::ScalarType::Half ||
ql_nope.scalar_type() == at::ScalarType::BFloat16,
"ql_nope must be float16 or bfloat16 dtype");
if (num_tokens == 0) return;
......@@ -1501,7 +1504,7 @@ void concat_mla_q(torch::Tensor& ql_nope, // [num_tokens, num_heads, nope_dim]
const at::cuda::OptionalCUDAGuard device_guard(device_of(ql_nope));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
VLLM_DISPATCH_HALF_TYPES(ql_nope.scalar_type(), "concat_mla_q", [&] {
vllm::ConcatMLAQKernel<scalar_t, 512><<<grid_size, block_size, 0, stream>>>(
q_out.data_ptr<scalar_t>(), ql_nope.data_ptr<scalar_t>(),
q_pe.data_ptr<scalar_t>(), num_tokens, num_heads, q_out.stride(0),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment