Commit 4ddb5447 authored by zhuwenwen's avatar zhuwenwen
Browse files

update cutlass fa and pa

parent fdda4d82
......@@ -24,6 +24,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
// Writes `key`/`value` into the paged KV cache at the slots listed in
// `slot_mapping`. Declaration only; the definition documents the layouts as
//   key_cache:   [num_blocks, num_heads, block_size, head_size]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
// `kv_cache_dtype` selects the stored cache element type; `k_scale` and
// `v_scale` are the scale tensors used when the cache is quantized.
void reshape_and_cache_cuda(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
......
......@@ -270,6 +270,66 @@ __global__ void reshape_and_cache_kernel(
}
}
// Copies one token's key and value vectors into the paged KV cache.
//
// Indexing used by this kernel: blockIdx.x selects the token, and the threads
// of the block step over that token's num_heads * head_size elements in
// increments of blockDim.x.
//
// Destination layouts:
//   key_cache:   [num_blocks, num_heads, block_size, head_size]
//   value_cache: [num_blocks, num_heads, head_size, block_size]
//
// A negative slot_mapping entry marks a padding token; nothing is written for
// it. `x` is accepted for signature compatibility but is not read here.
// `k_scale` / `v_scale` are only dereferenced on the quantized (non-kAuto)
// paths.
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel_cuda(
    const scalar_t* __restrict__ key,    // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ value,  // [num_tokens, num_heads, head_size]
    cache_t* __restrict__ key_cache,     // [num_blocks, num_heads, block_size,
                                         //  head_size]
    cache_t* __restrict__ value_cache,   // [num_blocks, num_heads, head_size,
                                         //  block_size]
    const int64_t* __restrict__ slot_mapping,  // [num_tokens]
    const int key_stride, const int value_stride, const int num_heads,
    const int head_size, const int block_size, int x,
    const float* k_scale, const float* v_scale) {
  const int64_t token = blockIdx.x;
  const int64_t slot = slot_mapping[token];
  // Padding token: skip without touching the cache.
  if (slot < 0) return;

  const int64_t page = slot / block_size;
  const int64_t pos_in_page = slot % block_size;
  const int elems_per_token = num_heads * head_size;

  // Start of this token's row in the source tensors.
  const scalar_t* key_row = key + token * key_stride;
  const scalar_t* value_row = value + token * value_stride;

  // Both caches hold num_heads * block_size * head_size elements per page, so
  // the page base offset is shared between K and V.
  const int64_t page_base = page * num_heads * block_size * head_size;

  for (int e = threadIdx.x; e < elems_per_token; e += blockDim.x) {
    const int head = e / head_size;
    const int dim = e % head_size;

    // Offset of (page, head) -- identical for K and V since each head owns
    // block_size * head_size elements in either layout.
    const int64_t head_base =
        page_base + static_cast<int64_t>(head) * block_size * head_size;

    // K is stored row-major as [block_size, head_size] within a head.
    const int64_t k_dst = head_base + pos_in_page * head_size + dim;
    // V is stored transposed as [head_size, block_size] within a head.
    const int64_t v_dst =
        head_base + static_cast<int64_t>(dim) * block_size + pos_in_page;

    const scalar_t k_val = key_row[e];
    const scalar_t v_val = value_row[e];

    if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
      // Cache dtype matches the input dtype: plain copy.
      key_cache[k_dst] = k_val;
      value_cache[v_dst] = v_val;
    } else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
      key_cache[k_dst] =
          int8::scaled_vec_conversion_int8<cache_t, scalar_t>(k_val, *k_scale);
      value_cache[v_dst] =
          int8::scaled_vec_conversion_int8<cache_t, scalar_t>(v_val, *v_scale);
    } else {
      key_cache[k_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(k_val, *k_scale);
      value_cache[v_dst] =
          fp8::scaled_convert<cache_t, scalar_t, kv_dt>(v_val, *v_scale);
    }
  }
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
......@@ -538,6 +598,56 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE)
}
// Launches vllm::reshape_and_cache_kernel_cuda for one (KV_T, CACHE_T,
// KV_DTYPE) instantiation.
//
// The expansion site must have these names in scope: grid, block, stream,
// key, value, key_cache, value_cache, slot_mapping, key_stride, value_stride,
// num_heads, head_size, block_size, k_scale, v_scale.
// No dynamic shared memory is requested (third launch argument is 0), and the
// kernel's `x` parameter is passed the literal 1.
#define CALL_RESHAPE_AND_CACHE_CUDA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel_cuda<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, 1, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
// Host entry point: validates shapes and launches
// reshape_and_cache_kernel_cuda on the current CUDA stream of `key`'s device.
//
// One thread block is launched per entry of `slot_mapping`, with
// min(num_heads * head_size, 512) threads each.
//
// Raises (via TORCH_CHECK) when the tensor ranks are wrong, when the key and
// value caches disagree with each other, or when the cache's head / head_size
// dimensions disagree with `key`. Returns without launching when there is
// nothing to copy.
void reshape_and_cache_cuda(
    torch::Tensor& key,    // [num_tokens, num_heads, head_size]
    torch::Tensor& value,  // [num_tokens, num_heads, head_size]
    torch::Tensor&
        key_cache,  // [num_blocks, num_heads, block_size, head_size]
    torch::Tensor&
        value_cache,  // [num_blocks, num_heads, head_size, block_size]
    torch::Tensor& slot_mapping,  // [num_tokens]
    const std::string& kv_cache_dtype, torch::Tensor& k_scale,
    torch::Tensor& v_scale) {
  TORCH_CHECK(key.dim() == 3 && value.dim() == 3,
              "key/value must be [num_tokens, num_heads, head_size]");
  TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4,
              "cache tensor shape mismatch");
  TORCH_CHECK(key_cache.size(0) == value_cache.size(0) &&
                  key_cache.size(1) == value_cache.size(1) &&
                  key_cache.size(2) == value_cache.size(3) &&
                  key_cache.size(3) == value_cache.size(2),
              "key/value cache dimension mismatch");

  int num_tokens = slot_mapping.size(0);
  int num_heads = key.size(1);
  int head_size = key.size(2);
  // K layout: [num_blocks, num_heads, block_size, head_size]
  int block_size = key_cache.size(2);

  // The kernel derives every cache offset from key's num_heads / head_size,
  // so a cache allocated with different dimensions would be written at the
  // wrong offsets (possibly out of bounds). Reject that up front.
  TORCH_CHECK(key_cache.size(1) == num_heads &&
                  key_cache.size(3) == head_size,
              "key_cache must be [num_blocks, num_heads, block_size, "
              "head_size] with num_heads/head_size matching key");

  // A zero-sized grid or block is an invalid launch configuration; with no
  // tokens (or no elements per token) there is nothing to copy anyway.
  if (num_tokens == 0 || num_heads * head_size == 0) {
    return;
  }

  int key_stride = key.stride(0);
  int value_stride = value.stride(0);

  dim3 grid(num_tokens);
  dim3 block(std::min(num_heads * head_size, 512));
  const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
                             CALL_RESHAPE_AND_CACHE_CUDA);
}
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
......
......@@ -845,6 +845,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key and value tensors and cache them (reshape_and_cache_cuda variant).
cache_ops.def(
"reshape_and_cache_cuda(Tensor key, Tensor value, "
"Tensor! key_cache, Tensor! value_cache, Tensor slot_mapping, "
"str kv_cache_dtype, Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_cuda",
torch::kCUDA,
&reshape_and_cache_cuda);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
......
......@@ -9,6 +9,7 @@ prompts = [
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"Hello, my name is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
......@@ -16,7 +17,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True)
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True, block_size=16, enable_prefix_caching=False)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
......
......@@ -2070,6 +2070,21 @@ def reshape_and_cache(
kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_cuda(
    key: torch.Tensor,
    value: torch.Tensor,
    key_cache: torch.Tensor,
    value_cache: torch.Tensor,
    slot_mapping: torch.Tensor,
    kv_cache_dtype: str,
    k_scale: torch.Tensor,
    v_scale: torch.Tensor,
) -> None:
    """Forward to the ``_C_cache_ops.reshape_and_cache_cuda`` custom op.

    All arguments are passed through positionally and unchanged; nothing is
    returned. The op is expected to update ``key_cache`` and ``value_cache``
    in place (they are declared ``Tensor!`` in the op schema).
    """
    cache_op = torch.ops._C_cache_ops.reshape_and_cache_cuda
    cache_op(key, value, key_cache, value_cache, slot_mapping,
             kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
......
......@@ -580,17 +580,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
if SUPPORT_TC:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func # , vllm_flash_attn_with_kvcache # noqa: F401
self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN and gpuname.startswith('K100_AI'):
from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
# self.fa_decode_attn_func = vllm_flash_attn_with_kvcache
logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
else:
self.use_naive_attn = True
envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN = True
if self.use_naive_attn:
if logits_soft_cap is not None:
......@@ -857,7 +857,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
# prefix-enabled attention -
# not applicable for encoder-only models
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN or (not gpuname.startswith('K100_AI')):
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix(
......@@ -889,19 +889,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] --->
cutlass: num_blocks x page_block_size x num_heads_k x head_size i
'''
num_blocks, num_kv_heads, head_size_div_x, block_size, x = key_cache.shape
head_size = head_size_div_x * x
key_cache_flash = key_cache.permute(0, 3, 1, 2, 4) # [num_blocks, block_size, num_kv_heads, head_size//x, x]
key_cache_flash = key_cache_flash.reshape(num_blocks, block_size, num_kv_heads, head_size)
# value_cache
value_cache_flash = value_cache.permute(0, 3, 1, 2) # [num_blocks, block_size, num_kv_heads, head_size]
output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa
q=query,
k=key_cache_flash,
v=value_cache_flash,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
......@@ -977,28 +968,48 @@ class ROCmFlashAttentionImpl(AttentionImpl):
)
else:
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
)
if envs.VLLM_USE_FLASH_ATTN_BACKEND:
from flash_attn import vllm_flash_attn_with_kvcache
# output[num_prefill_tokens:] = self.fa_decode_attn_func(
output[num_prefill_tokens:] = vllm_flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=decode_meta.seq_lens_tensor,
block_table=decode_meta.block_tables,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
alibi_slopes=self.alibi_slopes,
return_softmax_lse=False,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
kv_cache_dtype=self.kv_cache_dtype,
).squeeze(1)
else:
output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
......
......@@ -58,12 +58,23 @@ class PagedAttention:
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
'''
CUTLASS key_cache layout: [num_blocks, num_kv_heads, block_size, head_size]
Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x]
value_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
'''
if envs.VLLM_USE_FLASH_ATTN_BACKEND:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1]
value_cache=value_cache.view(num_blocks, num_kv_heads,head_size, -1)
else:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
......@@ -77,16 +88,29 @@ class PagedAttention:
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
if envs.VLLM_USE_FLASH_ATTN_BACKEND:
ops.reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
......
......@@ -150,6 +150,7 @@ if TYPE_CHECKING:
VLLM_TBO_DECODE_BS: int = 0
VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_BACKEND: bool = False
def get_default_cache_root():
......@@ -990,6 +991,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
# If set to "true" or "1" (case-insensitive), vLLM uses the FlashAttention
# backend for attention computation on ROCm. Disabled by default.
"VLLM_USE_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_BACKEND", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment