Commit c5877810 authored by zhangshao's avatar zhangshao
Browse files

增加bw支持

parent a715dfbe
...@@ -567,7 +567,7 @@ __global__ void paged_attention_v1_kernel_TC( ...@@ -567,7 +567,7 @@ __global__ void paged_attention_v1_kernel_TC(
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#ifdef __gfx928__ #if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>( KV_DTYPE, IS_BLOCK_SPARSE,REUSE_KV_TIMES,use_vmac>(
/* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache, /* exp_sums */ nullptr, /* max_logits */ nullptr, out, q, k_cache,
...@@ -607,7 +607,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC( ...@@ -607,7 +607,7 @@ __global__ __launch_bounds__(256, 1) void paged_attention_v2_kernel_TC(
const int blocksparse_local_blocks, const int blocksparse_vert_stride, const int blocksparse_local_blocks, const int blocksparse_vert_stride,
const int blocksparse_block_size, const int blocksparse_head_sliding_step, const int blocksparse_block_size, const int blocksparse_head_sliding_step,
const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) { const int* __restrict__ attn_masks=nullptr, const int attn_masks_stride=0) {
#ifdef __gfx928__ #if defined(__gfx936__) || defined(__gfx928__)
paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, paged_attention_kernel_TC<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS,
KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac, KV_DTYPE, IS_BLOCK_SPARSE, REUSE_KV_TIMES,use_vmac,
PARTITION_SIZE>( PARTITION_SIZE>(
...@@ -952,7 +952,7 @@ void paged_attention_v1_opt_tc( ...@@ -952,7 +952,7 @@ void paged_attention_v1_opt_tc(
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads, paged_attention_v1_opt(out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
...@@ -1182,7 +1182,7 @@ void paged_attention_v2_opt_tc( ...@@ -1182,7 +1182,7 @@ void paged_attention_v2_opt_tc(
const int64_t attn_masks_stride) { const int64_t attn_masks_stride) {
const bool is_block_sparse = (blocksparse_vert_stride > 1); const bool is_block_sparse = (blocksparse_vert_stride > 1);
if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse|| if(kv_cache_dtype != "auto"||query.dtype() == at::ScalarType::Float||is_block_sparse||
block_size!=16||query.size(2)!=128||get_device_name()!="gfx928"){ block_size!=16||query.size(2)!=128||(get_device_name()!="gfx928" && get_device_name()!="gfx936")){
paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads, paged_attention_v2_opt(out,exp_sums,max_logits,tmp_out,query,key_cache,value_cache,num_kv_heads,
scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype, scale,block_tables,seq_lens,block_size,max_seq_len,alibi_slopes,kv_cache_dtype,
k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride, k_scale,v_scale,tp_rank,blocksparse_local_blocks,blocksparse_vert_stride,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment