Commit 4ddb5447 authored by zhuwenwen's avatar zhuwenwen
Browse files

Update CUTLASS flash attention (FA) and paged attention (PA)

parent fdda4d82
...@@ -24,6 +24,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, ...@@ -24,6 +24,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale); torch::Tensor& k_scale, torch::Tensor& v_scale);
// Writes the per-token key/value tensors into the paged KV cache at the slots
// given by slot_mapping. Takes the same argument list as reshape_and_cache().
// NOTE(review): the definition documents the cache layouts as
//   key_cache:   [num_blocks, num_heads, block_size, head_size]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
// -- keep this declaration in sync with it.
void reshape_and_cache_cuda(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
......
...@@ -270,6 +270,66 @@ __global__ void reshape_and_cache_kernel( ...@@ -270,6 +270,66 @@ __global__ void reshape_and_cache_kernel(
} }
} }
// Scatters the key/value vectors of one token into the paged KV cache.
//
// Launch layout: blockIdx.x selects the token; the threads of the block
// stride (by blockDim.x) over the num_heads * head_size elements of that
// token. No shared memory is used.
//
// Cache layouts written by this kernel:
//   key_cache:   [num_blocks, num_heads, block_size, head_size]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
//
// Tokens whose slot_mapping entry is negative are treated as padding and
// skipped. The source element of a token is addressed as
// token * stride + flat_element, i.e. the (head, head_offset) part of
// key/value is read as one flat run of num_heads * head_size elements.
// The `x` parameter is not read by this kernel.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel_cuda(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,
    // [num_blocks, num_heads, block_size, head_size]
    cache_t* __restrict__ value_cache,
    // [num_blocks, num_heads, head_size, block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, int x,
    const float* k_scale, const float* v_scale) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  if (slot < 0) {
    // Padding token: nothing to store.
    return;
  }

  // Split the flat slot id into (cache block, position inside the block).
  const int64_t cache_block = slot / block_size;
  const int64_t pos_in_block = slot % block_size;

  // Offsets that are identical for every element of this token.
  const int64_t key_block_base =
      cache_block * num_heads * block_size * head_size;
  const int64_t value_block_base =
      cache_block * num_heads * head_size * block_size;
  const int64_t key_row_offset = pos_in_block * head_size;
  const int64_t src_key_base = token * key_stride;
  const int64_t src_value_base = token * value_stride;

  const int elems_per_token = num_heads * head_size;
  for (int elem = threadIdx.x; elem < elems_per_token; elem += blockDim.x) {
    const int head = elem / head_size;
    const int dim = elem % head_size;

    // K destination: [cache_block, head, pos_in_block, dim]
    const int64_t key_dst =
        key_block_base + head * block_size * head_size + key_row_offset + dim;
    // V destination: [cache_block, head, dim, pos_in_block]
    const int64_t value_dst = value_block_base +
                              head * head_size * block_size +
                              dim * block_size + pos_in_block;

    scalar_t key_elem = key[src_key_base + elem];
    scalar_t value_elem = value[src_value_base + elem];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      // Cache stores the input values directly.
      key_cache[key_dst] = key_elem;
      value_cache[value_dst] = value_elem;
    } else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
      key_cache[key_dst] =
          int8::scaled_vec_conversion_int8<cache_t, scalar_t>(key_elem,
                                                              *k_scale);
      value_cache[value_dst] =
          int8::scaled_vec_conversion_int8<cache_t, scalar_t>(value_elem,
                                                              *v_scale);
    } else {
      key_cache[key_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(key_elem, *k_scale);
      value_cache[value_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(value_elem, *v_scale);
    }
  }
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel( __global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
...@@ -538,6 +598,56 @@ void reshape_and_cache( ...@@ -538,6 +598,56 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE) CALL_RESHAPE_AND_CACHE)
} }
// Launches vllm::reshape_and_cache_kernel_cuda for one
// (KV_T, CACHE_T, KV_DTYPE) combination.
// The expansion refers to names that must exist at the expansion site:
// key, value, key_cache, value_cache, slot_mapping, k_scale, v_scale,
// grid, block, stream, key_stride, value_stride, num_heads, head_size and
// block_size. The kernel's `x` argument is always passed as 1.
#define CALL_RESHAPE_AND_CACHE_CUDA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel_cuda<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, 1, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
// Host entry point: stores the key/value vectors of every token into the
// paged KV cache at the slot given by slot_mapping.
//
// Expected shapes (validated below):
//   key, value:  [num_tokens, num_heads, head_size]
//   key_cache:   [num_blocks, num_heads, block_size, head_size]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
//   slot_mapping: one int64 slot id per token; the number of tokens processed
//                 is slot_mapping.size(0).
//
// The kernel addresses a token's source data as stride(0) * token + flat
// element, so key/value are assumed to be dense in their last two dimensions.
// NOTE(review): that density is not verified here -- confirm against callers.
//
// The launch is asynchronous on the current stream of key's device; one
// thread block is launched per token.
void reshape_and_cache_cuda(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  TORCH_CHECK(key.dim() == 3 && value.dim() == 3,
              "key/value must be [num_tokens, num_heads, head_size]");
  TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4,
              "cache tensor shape mismatch");
  TORCH_CHECK(key_cache.size(0) == value_cache.size(0) &&
                  key_cache.size(1) == value_cache.size(1) &&
                  key_cache.size(2) == value_cache.size(3) &&
                  key_cache.size(3) == value_cache.size(2),
              "key/value cache dimension mismatch");
  // The kernel uses a single (num_heads, head_size) pair, taken from `key`,
  // to index value and both caches. Reject mismatches up front: they would
  // otherwise turn into out-of-bounds or misplaced device writes.
  TORCH_CHECK(value.size(1) == key.size(1) && value.size(2) == key.size(2),
              "key and value must have the same num_heads and head_size");
  TORCH_CHECK(key_cache.size(1) == key.size(1) &&
                  key_cache.size(3) == key.size(2),
              "key/value num_heads and head_size must match the cache layout "
              "[num_blocks, num_heads, block_size, head_size]");

  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  // K layout: [num_blocks, num_heads, block_size, head_size]
  int block_size = key_cache.size(2);

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  // Nothing to copy. A launch with a zero-sized grid or block is an invalid
  // configuration, so return before building the launch dimensions.
  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  // One block per token; threads stride over the token's elements.
  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_CUDA);
}
// KV_T is the data type of key and value tensors. // KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache. // CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache. // KV_DTYPE is the real data type of kv-cache.
......
...@@ -845,6 +845,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -845,6 +845,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor k_scale, Tensor v_scale) -> ()"); " Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the new key and value tensors and cache them.
// Registers the reshape_and_cache_cuda variant. key_cache and value_cache are
// declared mutable (Tensor!) in the schema; the op returns nothing.
cache_ops.def(
"reshape_and_cache_cuda(Tensor key, Tensor value, "
"Tensor! key_cache, Tensor! value_cache, Tensor slot_mapping, "
"str kv_cache_dtype, Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_cuda",
torch::kCUDA,
&reshape_and_cache_cuda);
// Reshape the key and value tensors and cache them. // Reshape the key and value tensors and cache them.
cache_ops.def( cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value," "reshape_and_cache_flash(Tensor key, Tensor value,"
......
...@@ -9,6 +9,7 @@ prompts = [ ...@@ -9,6 +9,7 @@ prompts = [
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
"Hello, my name is",
] ]
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16) sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
...@@ -16,7 +17,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16) ...@@ -16,7 +17,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
def main(): def main():
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True) llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True, block_size=16, enable_prefix_caching=False)
# Generate texts from the prompts. # Generate texts from the prompts.
# The output is a list of RequestOutput objects # The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
......
...@@ -2070,6 +2070,21 @@ def reshape_and_cache( ...@@ -2070,6 +2070,21 @@ def reshape_and_cache(
kv_cache_dtype, k_scale, v_scale) kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_cuda(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Forward to the compiled ``_C_cache_ops.reshape_and_cache_cuda`` op.

    All arguments are passed through unchanged and in the same order.
    Nothing is returned; the op's schema marks ``key_cache`` and
    ``value_cache`` as mutated in place.
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache_cuda
    cache_op(
        key,
        value,
        key_cache,
        value_cache,
        slot_mapping,
        kv_cache_dtype,
        k_scale,
        v_scale,
    )
def reshape_and_cache_flash( def reshape_and_cache_flash(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
......
...@@ -580,17 +580,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -580,17 +580,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
if SUPPORT_TC: if SUPPORT_TC:
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func # , vllm_flash_attn_with_kvcache # noqa: F401
self.fa_attn_func = flash_attn_varlen_func self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN and gpuname.startswith('K100_AI'): self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
from flash_attn import vllm_flash_attn_varlen_func # self.fa_decode_attn_func = vllm_flash_attn_with_kvcache
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
logger.debug("Using CUTLASS FA in ROCmBackend") logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
else: else:
self.use_naive_attn = True self.use_naive_attn = True
envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN = True
if self.use_naive_attn: if self.use_naive_attn:
if logits_soft_cap is not None: if logits_soft_cap is not None:
...@@ -857,7 +857,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -857,7 +857,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
# prefix-enabled attention - # prefix-enabled attention -
# not applicable for encoder-only models # not applicable for encoder-only models
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN or (not gpuname.startswith('K100_AI')): if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
version_key = triton_key() version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY: if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix( output[:num_prefill_tokens] = paged_attn.forward_prefix(
...@@ -889,19 +889,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -889,19 +889,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] ---> triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] --->
cutlass: num_blocks x page_block_size x num_heads_k x head_size i cutlass: num_blocks x page_block_size x num_heads_k x head_size i
''' '''
num_blocks, num_kv_heads, head_size_div_x, block_size, x = key_cache.shape
head_size = head_size_div_x * x
key_cache_flash = key_cache.permute(0, 3, 1, 2, 4) # [num_blocks, block_size, num_kv_heads, head_size//x, x]
key_cache_flash = key_cache_flash.reshape(num_blocks, block_size, num_kv_heads, head_size)
# value_cache
value_cache_flash = value_cache.permute(0, 3, 1, 2) # [num_blocks, block_size, num_kv_heads, head_size]
output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa
q=query, q=query,
k=key_cache_flash, k=key_cache,
v=value_cache_flash, v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc, cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len, max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor, seqused_k=prefill_meta.seq_lens_tensor,
...@@ -977,28 +968,48 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -977,28 +968,48 @@ class ROCmFlashAttentionImpl(AttentionImpl):
) )
else: else:
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
output[num_prefill_tokens:] = paged_attn.forward_decode( if envs.VLLM_USE_FLASH_ATTN_BACKEND:
decode_query, from flash_attn import vllm_flash_attn_with_kvcache
key_cache, # output[num_prefill_tokens:] = self.fa_decode_attn_func(
value_cache, output[num_prefill_tokens:] = vllm_flash_attn_with_kvcache(
decode_meta.block_tables q=decode_query.unsqueeze(1),
if self.attn_type != AttentionType.ENCODER_DECODER else k_cache=key_cache,
decode_meta.cross_block_tables, v_cache=value_cache,
decode_meta.seq_lens_tensor cache_seqlens=decode_meta.seq_lens_tensor,
if self.attn_type != AttentionType.ENCODER_DECODER else block_table=decode_meta.block_tables,
decode_meta.encoder_seq_lens_tensor, softmax_scale=self.scale,
decode_meta.max_decode_seq_len causal=True,
if self.attn_type != AttentionType.ENCODER_DECODER else window_size=self.sliding_window,
decode_meta.max_encoder_seq_len, softcap=self.logits_soft_cap,
self.kv_cache_dtype, alibi_slopes=self.alibi_slopes,
self.num_kv_heads, return_softmax_lse=False,
self.scale, k_scale=layer._k_scale,
self.alibi_slopes, v_scale=layer._v_scale,
layer._k_scale, kv_cache_dtype=self.kv_cache_dtype,
layer._v_scale, ).squeeze(1)
attn_masks=tree_attention_masks_tensor, else:
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0 output[num_prefill_tokens:] = paged_attn.forward_decode(
) decode_query,
key_cache,
value_cache,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
)
# Reshape the output tensor. # Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size) return output.view(-1, self.num_heads * self.head_size)
......
...@@ -58,12 +58,23 @@ class PagedAttention: ...@@ -58,12 +58,23 @@ class PagedAttention:
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size() x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1] num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0] '''
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, CUTLASS key_cache layout: [num_blocks, num_kv_heads, block_size, head_size]
-1, x) Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x]
value_cache = kv_cache[1] value_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) '''
if envs.VLLM_USE_FLASH_ATTN_BACKEND:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1]
value_cache=value_cache.view(num_blocks, num_kv_heads,head_size, -1)
else:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache return key_cache, value_cache
@staticmethod @staticmethod
...@@ -77,16 +88,29 @@ class PagedAttention: ...@@ -77,16 +88,29 @@ class PagedAttention:
k_scale: torch.Tensor, k_scale: torch.Tensor,
v_scale: torch.Tensor, v_scale: torch.Tensor,
) -> None: ) -> None:
ops.reshape_and_cache( if envs.VLLM_USE_FLASH_ATTN_BACKEND:
key, ops.reshape_and_cache_cuda(
value, key,
key_cache, value,
value_cache, key_cache,
slot_mapping.flatten(), value_cache,
kv_cache_dtype, slot_mapping.flatten(),
k_scale, kv_cache_dtype,
v_scale, k_scale,
) v_scale,
)
else:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod @staticmethod
def forward_decode( def forward_decode(
......
...@@ -150,6 +150,7 @@ if TYPE_CHECKING: ...@@ -150,6 +150,7 @@ if TYPE_CHECKING:
VLLM_TBO_DECODE_BS: int = 0 VLLM_TBO_DECODE_BS: int = 0
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_BACKEND: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -990,6 +991,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -990,6 +991,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will enable the moe_fused_gate kernel. # If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE": "VLLM_ENABLE_MOE_FUSED_GATE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))), lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
# If set to "true" or "1" (case-insensitive), vLLM will use the FlashAttention backend for attention computation on ROCm. Defaults to disabled.
"VLLM_USE_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_BACKEND", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment