Commit 9e27b5e4 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.1-dev' into v0.9.1-dev

parents 504c262e b2fa85ce
...@@ -965,19 +965,15 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -965,19 +965,15 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
) )
max_num_partitions=1; max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads; int blocks=max_num_partitions*batchsize*qheads;
if(device_name=="gfx928"){ if(batchsize>100&&max_seq_len>=2000){
if(batchsize*qheads>1024&&max_seq_len>=2000){ if(max_seq_len<3900)reusekv=4;
max_num_partitions=1;
if(max_seq_len<2000)reusekv=8;
else if(max_seq_len<3900)reusekv=4;
else{ else{
PARTITION_SIZE=2048; PARTITION_SIZE=1024;
reusekv=8; reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
} }
return; return;
} }
}
if(max_num_partitions==1){ if(max_num_partitions==1){
if(max_seq_len<512){ if(max_seq_len<512){
int bytes=max_seq_len*qheads*batchsize; int bytes=max_seq_len*qheads*batchsize;
...@@ -995,9 +991,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -995,9 +991,7 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(blocks<600||qheads<=kvheads*4){reusekv=4;return;} if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
reusekv=8;return; reusekv=8;return;
} }
if(device_name=="gfx928"){ if(batchsize>100&&max_seq_len>=2000){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<3900)reusekv=4; if(max_seq_len<3900)reusekv=4;
else{ else{
PARTITION_SIZE=2048; PARTITION_SIZE=2048;
...@@ -1006,7 +1000,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO ...@@ -1006,7 +1000,6 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
} }
return; return;
} }
}
if(max_seq_len<=1000|| if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64)) max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
max_num_partitions=1; max_num_partitions=1;
...@@ -1068,6 +1061,7 @@ void paged_attention_v2_launcher_opt_tc( ...@@ -1068,6 +1061,7 @@ void paged_attention_v2_launcher_opt_tc(
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512; int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256; if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks); get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(num_seqs>100&&max_num_partitions>16)max_num_partitions=16;
if(PA_PARTITION_SIZE!=0){ if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE; PARTITION_SIZE=PA_PARTITION_SIZE;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE); max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
......
...@@ -24,6 +24,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value, ...@@ -24,6 +24,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype, const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale); torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_cuda(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value, void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& key_cache,
torch::Tensor& value_cache, torch::Tensor& value_cache,
......
...@@ -270,6 +270,66 @@ __global__ void reshape_and_cache_kernel( ...@@ -270,6 +270,66 @@ __global__ void reshape_and_cache_kernel(
} }
} }
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel_cuda(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, block_size, head_size] target layout
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, int x,
const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
// ---------- calculate target index ----------
// K: [num_blocks, num_heads, block_size, head_size]
const int64_t tgt_key_idx =
block_idx * num_heads * block_size * head_size +
head_idx * block_size * head_size + block_offset * head_size +
head_offset;
// V: [num_blocks, num_heads, head_size, block_size]
const int64_t tgt_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
key_cache[tgt_key_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_key,
*k_scale);
value_cache[tgt_value_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_value,
*v_scale);
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
}
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt> template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel( __global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size] const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
...@@ -538,6 +598,56 @@ void reshape_and_cache( ...@@ -538,6 +598,56 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE) CALL_RESHAPE_AND_CACHE)
} }
#define CALL_RESHAPE_AND_CACHE_CUDA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel_cuda<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, 1, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_cuda(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
TORCH_CHECK(key.dim() == 3 && value.dim() == 3,
"key/value must be [num_tokens, num_heads, head_size]");
TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4,
"cache tensor shape mismatch");
TORCH_CHECK(key_cache.size(0) == value_cache.size(0) &&
key_cache.size(1) == value_cache.size(1) &&
key_cache.size(2) == value_cache.size(3) &&
key_cache.size(3) == value_cache.size(2),
"key/value cache dimension mismatch");
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(2); // k layout: [num_blocks, num_heads, block_size, head_size]
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE_CUDA);
}
// KV_T is the data type of key and value tensors. // KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache. // CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache. // KV_DTYPE is the real data type of kv-cache.
......
...@@ -845,6 +845,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) { ...@@ -845,6 +845,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor k_scale, Tensor v_scale) -> ()"); " Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache); cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key(new) and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_cuda(Tensor key, Tensor value, "
"Tensor! key_cache, Tensor! value_cache, Tensor slot_mapping, "
"str kv_cache_dtype, Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_cuda",
torch::kCUDA,
&reshape_and_cache_cuda);
// Reshape the key and value tensors and cache them. // Reshape the key and value tensors and cache them.
cache_ops.def( cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value," "reshape_and_cache_flash(Tensor key, Tensor value,"
......
...@@ -315,6 +315,16 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -315,6 +315,16 @@ See [this page](#generative-models) for more information on how to use generativ
* `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc. * `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc.
* *
* ✅︎ * ✅︎
- * `Ernie4_5_ForCausalLM`
* Ernie4.5
* `baidu/ERNIE-4.5-0.3B-PT`, etc.
*
* ✅︎
- * `Ernie4_5_MoeForCausalLM`
* Ernie4.5MoE
* `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc.
*
* ✅︎
- * `ExaoneForCausalLM` - * `ExaoneForCausalLM`
* EXAONE-3 * EXAONE-3
* `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. * `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
...@@ -575,6 +585,11 @@ See [this page](#generative-models) for more information on how to use generativ ...@@ -575,6 +585,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc. * `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎ * ✅︎
* ✅︎ * ✅︎
- * `MiniMaxM1ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-M1-40k`, etc.
*
* ✅︎
- * `MiniMaxText01ForCausalLM` - * `MiniMaxText01ForCausalLM`
* MiniMax-Text * MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc. * `MiniMaxAI/MiniMax-Text-01`, etc.
......
...@@ -9,6 +9,7 @@ prompts = [ ...@@ -9,6 +9,7 @@ prompts = [
"The president of the United States is", "The president of the United States is",
"The capital of France is", "The capital of France is",
"The future of AI is", "The future of AI is",
"Hello, my name is",
] ]
# Create a sampling params object. # Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16) sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
...@@ -16,7 +17,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16) ...@@ -16,7 +17,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
def main(): def main():
# Create an LLM. # Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True) llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True, block_size=16, enable_prefix_caching=False)
# Generate texts from the prompts. # Generate texts from the prompts.
# The output is a list of RequestOutput objects # The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information. # that contain the prompt, generated text, and other information.
......
...@@ -209,6 +209,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -209,6 +209,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True), trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-Text-01"), "MiniMaxText01ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-Text-01"),
trust_remote_code=True), trust_remote_code=True),
"MiniMaxM1ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-M1-40k"),
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "mistralai/Mistral-7B-Instruct-v0.1")), "MistralForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "mistralai/Mistral-7B-Instruct-v0.1")),
"MixtralForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501 "MixtralForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
{"tiny": os.path.join(models_path_prefix, "TitanML/tiny-mixtral")}), # noqa: E501 {"tiny": os.path.join(models_path_prefix, "TitanML/tiny-mixtral")}), # noqa: E501
...@@ -257,6 +259,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = { ...@@ -257,6 +259,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
tokenizer=os.path.join(models_path_prefix, "meta-llama/Llama-2-7b"), tokenizer=os.path.join(models_path_prefix, "meta-llama/Llama-2-7b"),
trust_remote_code=True), trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Zyphra/Zamba2-7B-instruct")), "Zamba2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Zyphra/Zamba2-7B-instruct")),
"Ernie4_5_ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-0.3B-PT"),
trust_remote_code=True),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT"),
trust_remote_code=True),
"MiMoForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"), "MiMoForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"),
trust_remote_code=True), trust_remote_code=True),
# [Encoder-decoder] # [Encoder-decoder]
......
...@@ -2070,6 +2070,21 @@ def reshape_and_cache( ...@@ -2070,6 +2070,21 @@ def reshape_and_cache(
kv_cache_dtype, k_scale, v_scale) kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_cuda(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache_cuda(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_flash( def reshape_and_cache_flash(
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
......
...@@ -580,17 +580,17 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -580,17 +580,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
if SUPPORT_TC: if SUPPORT_TC:
try: try:
from flash_attn import flash_attn_varlen_func # noqa: F401 from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func # , vllm_flash_attn_with_kvcache # noqa: F401
self.fa_attn_func = flash_attn_varlen_func self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN and gpuname.startswith('K100_AI'):
from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
# self.fa_decode_attn_func = vllm_flash_attn_with_kvcache
logger.debug("Using CUTLASS FA in ROCmBackend") logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError: except ModuleNotFoundError:
self.use_naive_attn = True self.use_naive_attn = True
else: else:
self.use_naive_attn = True self.use_naive_attn = True
envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN = True
if self.use_naive_attn: if self.use_naive_attn:
if logits_soft_cap is not None: if logits_soft_cap is not None:
...@@ -857,7 +857,7 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -857,7 +857,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else: else:
# prefix-enabled attention - # prefix-enabled attention -
# not applicable for encoder-only models # not applicable for encoder-only models
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN or (not gpuname.startswith('K100_AI')): if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
version_key = triton_key() version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY: if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix( output[:num_prefill_tokens] = paged_attn.forward_prefix(
...@@ -889,19 +889,10 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -889,19 +889,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] ---> triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] --->
cutlass: num_blocks x page_block_size x num_heads_k x head_size i cutlass: num_blocks x page_block_size x num_heads_k x head_size i
''' '''
num_blocks, num_kv_heads, head_size_div_x, block_size, x = key_cache.shape
head_size = head_size_div_x * x
key_cache_flash = key_cache.permute(0, 3, 1, 2, 4) # [num_blocks, block_size, num_kv_heads, head_size//x, x]
key_cache_flash = key_cache_flash.reshape(num_blocks, block_size, num_kv_heads, head_size)
# value_cache
value_cache_flash = value_cache.permute(0, 3, 1, 2) # [num_blocks, block_size, num_kv_heads, head_size]
output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa
q=query, q=query,
k=key_cache_flash, k=key_cache,
v=value_cache_flash, v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc, cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len, max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor, seqused_k=prefill_meta.seq_lens_tensor,
...@@ -977,6 +968,26 @@ class ROCmFlashAttentionImpl(AttentionImpl): ...@@ -977,6 +968,26 @@ class ROCmFlashAttentionImpl(AttentionImpl):
) )
else: else:
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
if envs.VLLM_USE_FLASH_ATTN_PA:
from flash_attn import vllm_flash_attn_with_kvcache
# output[num_prefill_tokens:] = self.fa_decode_attn_func(
output[num_prefill_tokens:] = vllm_flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=decode_meta.seq_lens_tensor,
block_table=decode_meta.block_tables,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
alibi_slopes=self.alibi_slopes,
return_softmax_lse=False,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
kv_cache_dtype=self.kv_cache_dtype,
).squeeze(1)
else:
output[num_prefill_tokens:] = paged_attn.forward_decode( output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query, decode_query,
key_cache, key_cache,
......
...@@ -59,6 +59,17 @@ class PagedAttention: ...@@ -59,6 +59,17 @@ class PagedAttention:
x = 16 // kv_cache.element_size() x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1] num_blocks = kv_cache.shape[1]
'''
CUTLASS key_cache layout: [num_blocks, num_kv_heads, block_size, head_size]
Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x]
value_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
'''
if envs.VLLM_USE_FLASH_ATTN_PA:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1]
value_cache=value_cache.view(num_blocks, num_kv_heads,head_size, -1)
else:
key_cache = kv_cache[0] key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x) -1, x)
...@@ -77,6 +88,18 @@ class PagedAttention: ...@@ -77,6 +88,18 @@ class PagedAttention:
k_scale: torch.Tensor, k_scale: torch.Tensor,
v_scale: torch.Tensor, v_scale: torch.Tensor,
) -> None: ) -> None:
if envs.VLLM_USE_FLASH_ATTN_PA:
ops.reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.reshape_and_cache( ops.reshape_and_cache(
key, key,
value, value,
...@@ -88,6 +111,7 @@ class PagedAttention: ...@@ -88,6 +111,7 @@ class PagedAttention:
v_scale, v_scale,
) )
@staticmethod @staticmethod
def forward_decode( def forward_decode(
query: torch.Tensor, query: torch.Tensor,
......
...@@ -67,7 +67,7 @@ class FixFunctionalizationPass(VllmInductorPass): ...@@ -67,7 +67,7 @@ class FixFunctionalizationPass(VllmInductorPass):
# self.defunctionalize(graph, node, mutated_args) # self.defunctionalize(graph, node, mutated_args)
# elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501 # elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
# mutated_args = {1: 'result', 2: 'scale', 3: 'residual'} # mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args) # self.defunctionalize(graph, node, mutated_args)
elif at_target in [ elif at_target in [
torch.ops._C.rms_norm.default, torch.ops._C.rms_norm.default,
# torch.ops._C.rms_norm_static_fp8_quant.default, # torch.ops._C.rms_norm_static_fp8_quant.default,
......
...@@ -4461,9 +4461,7 @@ class VllmConfig: ...@@ -4461,9 +4461,7 @@ class VllmConfig:
self.compilation_config.cudagraph_num_of_warmups = 1 self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.pass_config.enable_fusion = False self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_noop = False self.compilation_config.pass_config.enable_noop = False
# TODO self.compilation_config.level = CompilationLevel.PIECEWISE
# self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.level = CompilationLevel.NO_COMPILATION
self.compilation_config.set_splitting_ops_for_v1() self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes() self._set_cudagraph_sizes()
......
...@@ -143,13 +143,14 @@ if TYPE_CHECKING: ...@@ -143,13 +143,14 @@ if TYPE_CHECKING:
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16 VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False VLLM_FLASH_ATTN_V1: bool = False
VLLM_USE_NN: bool = False VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS: int = 0 VLLM_TBO_REQ_DELAY_MS: int = 0
VLLM_TBO_DECODE_BS: int = 0 VLLM_TBO_DECODE_BS: int = 0
VLLM_ZERO_OVERHEAD: bool = False VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False
def get_default_cache_root(): def get_default_cache_root():
...@@ -961,14 +962,14 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -961,14 +962,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_HAS_CONTEXT_DEFAULT": "VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))), lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM will use FlashAttention Backend for attention computation on rocm # If set, vLLM will use FlashAttention Backend for v1 attention computation on rocm
"VLLM_FLASH_ATTN_BACKEND": "VLLM_FLASH_ATTN_V1":
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in lambda: (os.environ.get("VLLM_FLASH_ATTN_V1", "False").lower() in
("true", "1")), ("true", "1")),
# If set, vLLM will transpose weight to use nn layout # If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN": "VLLM_USE_NN":
lambda: (os.environ.get("VLLM_USE_NN", "False").lower() in lambda: (os.environ.get("VLLM_USE_NN", "True").lower() in
("true", "1")), ("true", "1")),
# Enable two batch overlap. # Enable two batch overlap.
...@@ -990,6 +991,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -990,6 +991,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will enable the moe_fused_gate kernel. # If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE": "VLLM_ENABLE_MOE_FUSED_GATE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))), lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
# vLLM will use FlashAttention Backend for page attention computation on rocm
"VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -139,6 +139,7 @@ class BlockInt8LinearMethod(LinearMethodBase): ...@@ -139,6 +139,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
assert self.quant_config.weight_block_size is not None assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized assert self.quant_config.is_checkpoint_int8_serialized
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights( def create_weights(
self, self,
...@@ -231,8 +232,8 @@ class BlockInt8LinearMethod(LinearMethodBase): ...@@ -231,8 +232,8 @@ class BlockInt8LinearMethod(LinearMethodBase):
n=layer.weight.shape[0] n=layer.weight.shape[0]
k=layer.weight.shape[1] k=layer.weight.shape[1]
if {n,k} not in self.tritonsingleton.weight_shapes: if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k}) self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_blockint8json_name(n,k,self.block_size[0],self.block_size[1]) json_file=self.tritonsingleton.get_blockint8json_name(n,k,self.block_size[0],self.block_size[1])
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,self.block_size[0],self.block_size[1]) configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,self.block_size[0],self.block_size[1])
...@@ -260,7 +261,6 @@ class BlockInt8LinearMethod(LinearMethodBase): ...@@ -260,7 +261,6 @@ class BlockInt8LinearMethod(LinearMethodBase):
K=x.shape[1] K=x.shape[1]
N=layer.weight.shape[0] N=layer.weight.shape[0]
#print("self.tritonsingleton.triton_json_dict:",self.tritonsingleton.triton_json_dict)
#Get the best config options #Get the best config options
if len(self.tritonsingleton.triton_json_dict)==0: if len(self.tritonsingleton.triton_json_dict)==0:
config=None config=None
...@@ -293,8 +293,6 @@ class BlockInt8LinearMethod(LinearMethodBase): ...@@ -293,8 +293,6 @@ class BlockInt8LinearMethod(LinearMethodBase):
else: else:
config=None config=None
#print("m:{},n:{},k:{},config:{}".format(M,N,K,config))
return apply_w8a8_block_int8_linear( return apply_w8a8_block_int8_linear(
input=x, input=x,
weight=layer.weight, weight=layer.weight,
...@@ -431,6 +429,26 @@ class BlockInt8MoEMethod: ...@@ -431,6 +429,26 @@ class BlockInt8MoEMethod:
def process_weights_after_loading(self, layer: Module) -> None: def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading # Block quant doesn't need to process weights after loading
# warmup and get moe block-int8 config # warmup and get moe block-int8 config
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
block_size=self.quant_config.weight_block_size
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,block_size,)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
#生成模型配置文件
self.tritonsingleton.gen_model_json(block_size)
return return
def apply( def apply(
......
...@@ -597,8 +597,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -597,8 +597,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
k=layer.weight.shape[1] k=layer.weight.shape[1]
if self.w8a8_strategy==1: if self.w8a8_strategy==1:
if {n,k} not in self.tritonsingleton.weight_shapes: if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k}) self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_w8a8json_name(n,k) json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k) configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
...@@ -607,12 +607,13 @@ class CompressedTensorsLinearMethod(LinearMethodBase): ...@@ -607,12 +607,13 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
for key, value in configs_dict.items(): for key, value in configs_dict.items():
m=int(key.split('_')[0]) m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value) ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
else: else:
weight_data=layer.weight.data weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1) _weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight layer.weight.data=_weight
self.tritonsingleton.gen_model_json()
layer.scheme.process_weights_after_loading(layer) layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
......
...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( ...@@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types from vllm.scalar_type import scalar_types
from vllm.utils import W8a8GetCacheJSON
has_pplx = importlib.util.find_spec("pplx_kernels") is not None has_pplx = importlib.util.find_spec("pplx_kernels") is not None
...@@ -141,6 +142,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -141,6 +142,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
is_rocm_aiter_moe_enabled) is_rocm_aiter_moe_enabled)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled() self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(self, layer: torch.nn.Module, num_experts: int, def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int, hidden_size: int, intermediate_size_per_partition: int,
...@@ -226,6 +228,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod): ...@@ -226,6 +228,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
layer.w13_input_scale = None layer.w13_input_scale = None
layer.w2_input_scale = None layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
#生成模型配置文件
#self.tritonsingleton.gen_model_json(block_size)
return
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale. # Fp8 moe kernels require a single activation scale.
# We take the max of all the scales in case they differ. # We take the max of all the scales in case they differ.
......
...@@ -14,6 +14,8 @@ from vllm.platforms import current_platform ...@@ -14,6 +14,8 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (ScaledMMLinearKernel, from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
ScaledMMLinearLayerConfig) ScaledMMLinearLayerConfig)
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
...@@ -112,6 +114,11 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel): ...@@ -112,6 +114,11 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# * dynamic, i_s is None and x_s computed from x. # * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s. # * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None symmetric = azp_adj is None
if i_s is None and i_zp is None and symmetric is True:
x_q, x_s=per_token_quant_int8(x)
x_zp =None
else:
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(), x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
i_s, i_s,
i_zp, i_zp,
......
...@@ -10,6 +10,7 @@ from vllm import envs ...@@ -10,6 +10,7 @@ from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
# Input scaling factors are no longer optional in _scaled_mm starting # Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
...@@ -396,6 +397,11 @@ def apply_int8_linear( ...@@ -396,6 +397,11 @@ def apply_int8_linear(
# * dynamic, layer.input_scale is None and x_scale computed from x. # * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale. # * static, layer.input_scale is scalar and x_scale is input_scale.
symmetric = azp_adj is None symmetric = azp_adj is None
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input, x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale, input_scale,
input_zero_point, input_zero_point,
......
...@@ -248,6 +248,7 @@ class W8A8Int8MoEMethod: ...@@ -248,6 +248,7 @@ class W8A8Int8MoEMethod:
def __init__(self, quant_config): def __init__(self, quant_config):
self.quant_config = quant_config self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights( def create_weights(
self, self,
...@@ -302,6 +303,22 @@ class W8A8Int8MoEMethod: ...@@ -302,6 +303,22 @@ class W8A8Int8MoEMethod:
layer.register_parameter("w2_input_scale", w2_input_scale) layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None: def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter( layer.w13_weight_scale = Parameter(
......
...@@ -27,6 +27,7 @@ from .deepseek_v2 import (DeepseekV2DecoderLayer, ...@@ -27,6 +27,7 @@ from .deepseek_v2 import (DeepseekV2DecoderLayer,
from .interfaces import SupportsPP from .interfaces import SupportsPP
from .utils import maybe_prefix from .utils import maybe_prefix
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
class SharedHead(nn.Module): class SharedHead(nn.Module):
...@@ -164,6 +165,9 @@ class DeepSeekMTP(nn.Module, SupportsPP): ...@@ -164,6 +165,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
self.quant_method = quant_config.get_name() self.quant_method = quant_config.get_name()
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0' os.environ['LM_NN'] = '0'
# The AWQ layer of MTP uses BlockInt8W8A8.
if self.quant_method == "moe_wna16":
vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config, self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix( prefix=maybe_prefix(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment