Commit 9e27b5e4 authored by 王敏's avatar 王敏
Browse files

Merge remote-tracking branch 'origin/v0.9.1-dev' into v0.9.1-dev

parents 504c262e b2fa85ce
......@@ -965,18 +965,14 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
)
max_num_partitions=1;
int blocks=max_num_partitions*batchsize*qheads;
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<2000)reusekv=8;
else if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
if(batchsize>100&&max_seq_len>=2000){
if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=1024;
reusekv=8;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
if(max_num_partitions==1){
if(max_seq_len<512){
......@@ -995,17 +991,14 @@ void get_numberthread_and_reuse_kv_v2(int& num_thread,int& reusekv,int& PARTITIO
if(blocks<600||qheads<=kvheads*4){reusekv=4;return;}
reusekv=8;return;
}
if(device_name=="gfx928"){
if(batchsize*qheads>1024&&max_seq_len>=2000){
max_num_partitions=1;
if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
if(batchsize>100&&max_seq_len>=2000){
if(max_seq_len<3900)reusekv=4;
else{
PARTITION_SIZE=2048;
reusekv=4;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
}
return;
}
if(max_seq_len<=1000||
max_seq_len<=1500&&(qheads>4&&batchsize>=16||batchsize>=64))
......@@ -1068,6 +1061,7 @@ void paged_attention_v2_launcher_opt_tc(
int reusekv, num_thread,max_num_partitions,PARTITION_SIZE=512;
if(!is_half&&max_seq_len<=8192)PARTITION_SIZE=256;
get_numberthread_and_reuse_kv_v2(num_thread,reusekv,PARTITION_SIZE,max_num_partitions,num_seqs,max_seq_len,num_heads,num_kv_heads,num_blocks);
if(num_seqs>100&&max_num_partitions>16)max_num_partitions=16;
if(PA_PARTITION_SIZE!=0){
PARTITION_SIZE=PA_PARTITION_SIZE;
max_num_partitions=DIVIDE_ROUND_UP(max_seq_len, PARTITION_SIZE);
......
......@@ -24,6 +24,12 @@ void reshape_and_cache(torch::Tensor& key, torch::Tensor& value,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_cuda(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache, torch::Tensor& value_cache,
torch::Tensor& slot_mapping,
const std::string& kv_cache_dtype,
torch::Tensor& k_scale, torch::Tensor& v_scale);
void reshape_and_cache_flash(torch::Tensor& key, torch::Tensor& value,
torch::Tensor& key_cache,
torch::Tensor& value_cache,
......
......@@ -270,6 +270,66 @@ __global__ void reshape_and_cache_kernel(
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_kernel_cuda(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
const scalar_t* __restrict__ value, // [num_tokens, num_heads, head_size]
cache_t* __restrict__ key_cache, // [num_blocks, num_heads, block_size, head_size] target layout
cache_t* __restrict__ value_cache, // [num_blocks, num_heads, head_size, block_size]
const int64_t* __restrict__ slot_mapping, // [num_tokens]
const int key_stride, const int value_stride, const int num_heads,
const int head_size, const int block_size, int x,
const float* k_scale, const float* v_scale) {
const int64_t token_idx = blockIdx.x;
const int64_t slot_idx = slot_mapping[token_idx];
if (slot_idx < 0) {
// Padding token that should be ignored.
return;
}
const int64_t block_idx = slot_idx / block_size;
const int64_t block_offset = slot_idx % block_size;
const int n = num_heads * head_size;
for (int i = threadIdx.x; i < n; i += blockDim.x) {
const int64_t src_key_idx = token_idx * key_stride + i;
const int64_t src_value_idx = token_idx * value_stride + i;
const int head_idx = i / head_size;
const int head_offset = i % head_size;
// ---------- calculate target index ----------
// K: [num_blocks, num_heads, block_size, head_size]
const int64_t tgt_key_idx =
block_idx * num_heads * block_size * head_size +
head_idx * block_size * head_size + block_offset * head_size +
head_offset;
// V: [num_blocks, num_heads, head_size, block_size]
const int64_t tgt_value_idx =
block_idx * num_heads * head_size * block_size +
head_idx * head_size * block_size + head_offset * block_size +
block_offset;
scalar_t tgt_key = key[src_key_idx];
scalar_t tgt_value = value[src_value_idx];
if constexpr (kv_dt == Fp8KVCacheDataType::kAuto) {
key_cache[tgt_key_idx] = tgt_key;
value_cache[tgt_value_idx] = tgt_value;
} else if constexpr (kv_dt == Fp8KVCacheDataType::kInt8) {
key_cache[tgt_key_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_key,
*k_scale);
value_cache[tgt_value_idx] =
int8::scaled_vec_conversion_int8<cache_t, scalar_t>(tgt_value,
*v_scale);
} else {
key_cache[tgt_key_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_key, *k_scale);
value_cache[tgt_value_idx] =
fp8::scaled_convert<cache_t, scalar_t, kv_dt>(tgt_value, *v_scale);
}
}
}
template <typename scalar_t, typename cache_t, Fp8KVCacheDataType kv_dt>
__global__ void reshape_and_cache_flash_kernel(
const scalar_t* __restrict__ key, // [num_tokens, num_heads, head_size]
......@@ -538,6 +598,56 @@ void reshape_and_cache(
CALL_RESHAPE_AND_CACHE)
}
#define CALL_RESHAPE_AND_CACHE_CUDA(KV_T, CACHE_T, KV_DTYPE) \
vllm::reshape_and_cache_kernel_cuda<KV_T, CACHE_T, KV_DTYPE> \
<<<grid, block, 0, stream>>>( \
reinterpret_cast<KV_T*>(key.data_ptr()), \
reinterpret_cast<KV_T*>(value.data_ptr()), \
reinterpret_cast<CACHE_T*>(key_cache.data_ptr()), \
reinterpret_cast<CACHE_T*>(value_cache.data_ptr()), \
slot_mapping.data_ptr<int64_t>(), key_stride, value_stride, \
num_heads, head_size, block_size, 1, \
reinterpret_cast<const float*>(k_scale.data_ptr()), \
reinterpret_cast<const float*>(v_scale.data_ptr()));
void reshape_and_cache_cuda(
torch::Tensor& key, // [num_tokens, num_heads, head_size]
torch::Tensor& value, // [num_tokens, num_heads, head_size]
torch::Tensor&
key_cache, // [num_blocks, num_heads, block_size, head_size]
torch::Tensor&
value_cache, // [num_blocks, num_heads, head_size, block_size]
torch::Tensor& slot_mapping, // [num_tokens]
const std::string& kv_cache_dtype, torch::Tensor& k_scale,
torch::Tensor& v_scale) {
TORCH_CHECK(key.dim() == 3 && value.dim() == 3,
"key/value must be [num_tokens, num_heads, head_size]");
TORCH_CHECK(key_cache.dim() == 4 && value_cache.dim() == 4,
"cache tensor shape mismatch");
TORCH_CHECK(key_cache.size(0) == value_cache.size(0) &&
key_cache.size(1) == value_cache.size(1) &&
key_cache.size(2) == value_cache.size(3) &&
key_cache.size(3) == value_cache.size(2),
"key/value cache dimension mismatch");
int num_tokens = slot_mapping.size(0);
int num_heads = key.size(1);
int head_size = key.size(2);
int block_size = key_cache.size(2); // k layout: [num_blocks, num_heads, block_size, head_size]
int key_stride = key.stride(0);
int value_stride = value.stride(0);
dim3 grid(num_tokens);
dim3 block(std::min(num_heads * head_size, 512));
const at::cuda::OptionalCUDAGuard device_guard(device_of(key));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
DISPATCH_BY_KV_CACHE_DTYPE(key.dtype(), kv_cache_dtype,
CALL_RESHAPE_AND_CACHE_CUDA);
}
// KV_T is the data type of key and value tensors.
// CACHE_T is the stored data type of kv-cache.
// KV_DTYPE is the real data type of kv-cache.
......
......@@ -845,6 +845,15 @@ TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cache_ops), cache_ops) {
" Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache", torch::kCUDA, &reshape_and_cache);
// Reshape the key(new) and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_cuda(Tensor key, Tensor value, "
"Tensor! key_cache, Tensor! value_cache, Tensor slot_mapping, "
"str kv_cache_dtype, Tensor k_scale, Tensor v_scale) -> ()");
cache_ops.impl("reshape_and_cache_cuda",
torch::kCUDA,
&reshape_and_cache_cuda);
// Reshape the key and value tensors and cache them.
cache_ops.def(
"reshape_and_cache_flash(Tensor key, Tensor value,"
......
......@@ -315,6 +315,16 @@ See [this page](#generative-models) for more information on how to use generativ
* `deepseek-ai/DeepSeek-V3-Base`, `deepseek-ai/DeepSeek-V3` etc.
*
* ✅︎
- * `Ernie4_5_ForCausalLM`
* Ernie4.5
* `baidu/ERNIE-4.5-0.3B-PT`, etc.
*
* ✅︎
- * `Ernie4_5_MoeForCausalLM`
* Ernie4.5MoE
* `baidu/ERNIE-4.5-21B-A3B-PT`, `baidu/ERNIE-4.5-300B-A47B-PT`, etc.
*
* ✅︎
- * `ExaoneForCausalLM`
* EXAONE-3
* `LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc.
......@@ -575,6 +585,11 @@ See [this page](#generative-models) for more information on how to use generativ
* `xverse/XVERSE-7B-Chat`, `xverse/XVERSE-13B-Chat`, `xverse/XVERSE-65B-Chat`, etc.
* ✅︎
* ✅︎
- * `MiniMaxM1ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-M1-40k`, etc.
*
* ✅︎
- * `MiniMaxText01ForCausalLM`
* MiniMax-Text
* `MiniMaxAI/MiniMax-Text-01`, etc.
......
......@@ -9,6 +9,7 @@ prompts = [
"The president of the United States is",
"The capital of France is",
"The future of AI is",
"Hello, my name is",
]
# Create a sampling params object.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
......@@ -16,7 +17,7 @@ sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=16)
def main():
# Create an LLM.
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True)
llm = LLM(model="facebook/opt-125m",tensor_parallel_size=1, dtype="float16",trust_remote_code=True, enforce_eager=True, block_size=16, enable_prefix_caching=False)
# Generate texts from the prompts.
# The output is a list of RequestOutput objects
# that contain the prompt, generated text, and other information.
......
......@@ -209,6 +209,8 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
trust_remote_code=True),
"MiniMaxText01ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-Text-01"),
trust_remote_code=True),
"MiniMaxM1ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "MiniMaxAI/MiniMax-M1-40k"),
trust_remote_code=True),
"MistralForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "mistralai/Mistral-7B-Instruct-v0.1")),
"MixtralForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "mistralai/Mixtral-8x7B-Instruct-v0.1"), # noqa: E501
{"tiny": os.path.join(models_path_prefix, "TitanML/tiny-mixtral")}), # noqa: E501
......@@ -257,6 +259,10 @@ _TEXT_GENERATION_EXAMPLE_MODELS = {
tokenizer=os.path.join(models_path_prefix, "meta-llama/Llama-2-7b"),
trust_remote_code=True),
"Zamba2ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "Zyphra/Zamba2-7B-instruct")),
"Ernie4_5_ForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-0.3B-PT"),
trust_remote_code=True),
"Ernie4_5_MoeForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "baidu/ERNIE-4.5-21B-A3B-PT"),
trust_remote_code=True),
"MiMoForCausalLM": _HfExamplesInfo(os.path.join(models_path_prefix, "XiaomiMiMo/MiMo-7B-RL"),
trust_remote_code=True),
# [Encoder-decoder]
......
......@@ -2070,6 +2070,21 @@ def reshape_and_cache(
kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_cuda(
key: torch.Tensor,
value: torch.Tensor,
key_cache: torch.Tensor,
value_cache: torch.Tensor,
slot_mapping: torch.Tensor,
kv_cache_dtype: str,
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
torch.ops._C_cache_ops.reshape_and_cache_cuda(key, value, key_cache,
value_cache, slot_mapping,
kv_cache_dtype, k_scale, v_scale)
def reshape_and_cache_flash(
key: torch.Tensor,
value: torch.Tensor,
......
......@@ -580,17 +580,17 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
if SUPPORT_TC:
try:
from flash_attn import flash_attn_varlen_func # noqa: F401
from flash_attn import flash_attn_varlen_func, vllm_flash_attn_varlen_func # , vllm_flash_attn_with_kvcache # noqa: F401
self.fa_attn_func = flash_attn_varlen_func
if not envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN and gpuname.startswith('K100_AI'):
from flash_attn import vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
self.fa_prefix_attn_func = vllm_flash_attn_varlen_func
# self.fa_decode_attn_func = vllm_flash_attn_with_kvcache
logger.debug("Using CUTLASS FA in ROCmBackend")
except ModuleNotFoundError:
self.use_naive_attn = True
else:
self.use_naive_attn = True
envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN = True
if self.use_naive_attn:
if logits_soft_cap is not None:
......@@ -857,7 +857,7 @@ class ROCmFlashAttentionImpl(AttentionImpl):
else:
# prefix-enabled attention -
# not applicable for encoder-only models
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN or (not gpuname.startswith('K100_AI')):
if envs.VLLM_USE_TRITON_PREFIX_FLASH_ATTN:
version_key = triton_key()
if self.attn_type != AttentionType.ENCODER_ONLY:
output[:num_prefill_tokens] = paged_attn.forward_prefix(
......@@ -889,19 +889,10 @@ class ROCmFlashAttentionImpl(AttentionImpl):
triton: [GPU blocks, num_kv_heads, head_size // x, block_size, x] --->
cutlass: num_blocks x page_block_size x num_heads_k x head_size i
'''
num_blocks, num_kv_heads, head_size_div_x, block_size, x = key_cache.shape
head_size = head_size_div_x * x
key_cache_flash = key_cache.permute(0, 3, 1, 2, 4) # [num_blocks, block_size, num_kv_heads, head_size//x, x]
key_cache_flash = key_cache_flash.reshape(num_blocks, block_size, num_kv_heads, head_size)
# value_cache
value_cache_flash = value_cache.permute(0, 3, 1, 2) # [num_blocks, block_size, num_kv_heads, head_size]
output[:num_prefill_tokens] = self.fa_prefix_attn_func( # noqa
q=query,
k=key_cache_flash,
v=value_cache_flash,
k=key_cache,
v=value_cache,
cu_seqlens_q=prefill_meta.query_start_loc,
max_seqlen_q=prefill_meta.max_query_len,
seqused_k=prefill_meta.seq_lens_tensor,
......@@ -977,28 +968,48 @@ class ROCmFlashAttentionImpl(AttentionImpl):
)
else:
tree_attention_masks_tensor = decode_meta.tree_attention_masks_tensor
output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
)
if envs.VLLM_USE_FLASH_ATTN_PA:
from flash_attn import vllm_flash_attn_with_kvcache
# output[num_prefill_tokens:] = self.fa_decode_attn_func(
output[num_prefill_tokens:] = vllm_flash_attn_with_kvcache(
q=decode_query.unsqueeze(1),
k_cache=key_cache,
v_cache=value_cache,
cache_seqlens=decode_meta.seq_lens_tensor,
block_table=decode_meta.block_tables,
softmax_scale=self.scale,
causal=True,
window_size=self.sliding_window,
softcap=self.logits_soft_cap,
alibi_slopes=self.alibi_slopes,
return_softmax_lse=False,
k_scale=layer._k_scale,
v_scale=layer._v_scale,
kv_cache_dtype=self.kv_cache_dtype,
).squeeze(1)
else:
output[num_prefill_tokens:] = paged_attn.forward_decode(
decode_query,
key_cache,
value_cache,
decode_meta.block_tables
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.cross_block_tables,
decode_meta.seq_lens_tensor
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.encoder_seq_lens_tensor,
decode_meta.max_decode_seq_len
if self.attn_type != AttentionType.ENCODER_DECODER else
decode_meta.max_encoder_seq_len,
self.kv_cache_dtype,
self.num_kv_heads,
self.scale,
self.alibi_slopes,
layer._k_scale,
layer._v_scale,
attn_masks=tree_attention_masks_tensor,
attn_masks_stride=tree_attention_masks_tensor.stride(0) if tree_attention_masks_tensor is not None else 0
)
# Reshape the output tensor.
return output.view(-1, self.num_heads * self.head_size)
......
......@@ -58,12 +58,23 @@ class PagedAttention:
) -> Tuple[torch.Tensor, torch.Tensor]:
x = 16 // kv_cache.element_size()
num_blocks = kv_cache.shape[1]
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
'''
CUTLASS key_cache layout: [num_blocks, num_kv_heads, block_size, head_size]
Triton key_cache layout: [num_blocks, num_kv_heads, head_size // x, block_size, x]
value_cache layout: [num_blocks, num_kv_heads, head_size, block_size]
'''
if envs.VLLM_USE_FLASH_ATTN_PA:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, -1, head_size)
value_cache = kv_cache[1]
value_cache=value_cache.view(num_blocks, num_kv_heads,head_size, -1)
else:
key_cache = kv_cache[0]
key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x,
-1, x)
value_cache = kv_cache[1]
value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1)
return key_cache, value_cache
@staticmethod
......@@ -77,16 +88,29 @@ class PagedAttention:
k_scale: torch.Tensor,
v_scale: torch.Tensor,
) -> None:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
if envs.VLLM_USE_FLASH_ATTN_PA:
ops.reshape_and_cache_cuda(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
else:
ops.reshape_and_cache(
key,
value,
key_cache,
value_cache,
slot_mapping.flatten(),
kv_cache_dtype,
k_scale,
v_scale,
)
@staticmethod
def forward_decode(
......
......@@ -67,7 +67,7 @@ class FixFunctionalizationPass(VllmInductorPass):
# self.defunctionalize(graph, node, mutated_args)
# elif at_target == torch.ops._C.rms_norm_dynamic_per_token_quant.default: # noqa: E501
# mutated_args = {1: 'result', 2: 'scale', 3: 'residual'}
self.defunctionalize(graph, node, mutated_args)
# self.defunctionalize(graph, node, mutated_args)
elif at_target in [
torch.ops._C.rms_norm.default,
# torch.ops._C.rms_norm_static_fp8_quant.default,
......
......@@ -4461,9 +4461,7 @@ class VllmConfig:
self.compilation_config.cudagraph_num_of_warmups = 1
self.compilation_config.pass_config.enable_fusion = False
self.compilation_config.pass_config.enable_noop = False
# TODO
# self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.level = CompilationLevel.NO_COMPILATION
self.compilation_config.level = CompilationLevel.PIECEWISE
self.compilation_config.set_splitting_ops_for_v1()
self._set_cudagraph_sizes()
......
......@@ -143,13 +143,14 @@ if TYPE_CHECKING:
VLLM_CUSTOM_ALLREDUCE_SUPPORTED_WORLDSIZE_MAX: int = 16
VLLM_ENFORCE_EAGER_BS_THRESHOLD: Optional[int] = None
VLLM_HAS_CONTEXT_DEFAULT: bool = False
VLLM_FLASH_ATTN_BACKEND: bool = False
VLLM_FLASH_ATTN_V1: bool = False
VLLM_USE_NN: bool = False
VLLM_ENABLE_TBO: bool = False
VLLM_TBO_REQ_DELAY_MS: int = 0
VLLM_TBO_DECODE_BS: int = 0
VLLM_ZERO_OVERHEAD: bool = False
VLLM_ENABLE_MOE_FUSED_GATE: bool = False
VLLM_USE_FLASH_ATTN_PA: bool = False
def get_default_cache_root():
......@@ -961,14 +962,14 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_HAS_CONTEXT_DEFAULT":
lambda: bool(int(os.getenv("VLLM_HAS_CONTEXT_DEFAULT", "0"))),
# If set, vLLM will use FlashAttention Backend for attention computation on rocm
"VLLM_FLASH_ATTN_BACKEND":
lambda: (os.environ.get("VLLM_FLASH_ATTN_BACKEND", "False").lower() in
# If set, vLLM will use FlashAttention Backend for v1 attention computation on rocm
"VLLM_FLASH_ATTN_V1":
lambda: (os.environ.get("VLLM_FLASH_ATTN_V1", "False").lower() in
("true", "1")),
# If set, vLLM will transpose weight to use nn layout
"VLLM_USE_NN":
lambda: (os.environ.get("VLLM_USE_NN", "False").lower() in
lambda: (os.environ.get("VLLM_USE_NN", "True").lower() in
("true", "1")),
# Enable two batch overlap.
......@@ -990,6 +991,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
# If set, vLLM will enable the moe_fused_gate kernel.
"VLLM_ENABLE_MOE_FUSED_GATE":
lambda: bool(int(os.getenv("VLLM_ENABLE_MOE_FUSED_GATE", "1"))),
# vLLM will use FlashAttention Backend for page attention computation on rocm
"VLLM_USE_FLASH_ATTN_PA":
lambda: (os.environ.get("VLLM_USE_FLASH_ATTN_PA", "False").lower() in
("true", "1")),
}
# --8<-- [end:env-vars-definition]
......
......@@ -139,6 +139,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
assert self.quant_config.weight_block_size is not None
assert self.quant_config.is_checkpoint_int8_serialized
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(
self,
......@@ -231,8 +232,8 @@ class BlockInt8LinearMethod(LinearMethodBase):
n=layer.weight.shape[0]
k=layer.weight.shape[1]
if {n,k} not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k})
if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_blockint8json_name(n,k,self.block_size[0],self.block_size[1])
configs_dict=self.tritonsingleton.get_blockint8_triton_cache(json_file,n,k,self.block_size[0],self.block_size[1])
......@@ -260,7 +261,6 @@ class BlockInt8LinearMethod(LinearMethodBase):
K=x.shape[1]
N=layer.weight.shape[0]
#print("self.tritonsingleton.triton_json_dict:",self.tritonsingleton.triton_json_dict)
#Get the best config options
if len(self.tritonsingleton.triton_json_dict)==0:
config=None
......@@ -292,9 +292,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
else:
config=None
#print("m:{},n:{},k:{},config:{}".format(M,N,K,config))
return apply_w8a8_block_int8_linear(
input=x,
weight=layer.weight,
......@@ -431,6 +429,26 @@ class BlockInt8MoEMethod:
def process_weights_after_loading(self, layer: Module) -> None:
# Block quant doesn't need to process weights after loading
# warmup and get moe block-int8 config
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
block_size=self.quant_config.weight_block_size
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK,block_size,)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
#生成模型配置文件
self.tritonsingleton.gen_model_json(block_size)
return
def apply(
......
......@@ -597,8 +597,8 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
k=layer.weight.shape[1]
if self.w8a8_strategy==1:
if {n,k} not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append({n,k})
if [n,k] not in self.tritonsingleton.weight_shapes:
self.tritonsingleton.weight_shapes.append([n,k])
json_file=self.tritonsingleton.get_w8a8json_name(n,k)
configs_dict=self.tritonsingleton.get_triton_cache(json_file,n,k)
......@@ -607,12 +607,13 @@ class CompressedTensorsLinearMethod(LinearMethodBase):
for key, value in configs_dict.items():
m=int(key.split('_')[0])
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,best_config=value)
ops.triton_int8_gemm_helper(m=m,n=n,k=k,per_token_act_quant=True,per_out_channel_weight_quant=True,use_bias=False,device=layer.weight.device,best_config=value)
else:
weight_data=layer.weight.data
_weight=weight_data.T.contiguous().reshape(n,-1)
layer.weight.data=_weight
self.tritonsingleton.gen_model_json()
layer.scheme.process_weights_after_loading(layer)
def create_weights(self, layer: torch.nn.Module,
......
......@@ -29,6 +29,7 @@ from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.scalar_type import scalar_types
from vllm.utils import W8a8GetCacheJSON
has_pplx = importlib.util.find_spec("pplx_kernels") is not None
......@@ -141,6 +142,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
is_rocm_aiter_moe_enabled)
self.rocm_aiter_moe_enabled = is_rocm_aiter_moe_enabled()
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(self, layer: torch.nn.Module, num_experts: int,
hidden_size: int, intermediate_size_per_partition: int,
......@@ -225,6 +227,27 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
else:
layer.w13_input_scale = None
layer.w2_input_scale = None
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
#生成模型配置文件
#self.tritonsingleton.gen_model_json(block_size)
return
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# Fp8 moe kernels require a single activation scale.
......
......@@ -14,6 +14,8 @@ from vllm.platforms import current_platform
from .ScaledMMLinearKernel import (ScaledMMLinearKernel,
ScaledMMLinearLayerConfig)
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
......@@ -112,10 +114,15 @@ class CutlassScaledMMLinearKernel(ScaledMMLinearKernel):
# * dynamic, i_s is None and x_s computed from x.
# * static, i_s is scalar and x_s is i_s.
symmetric = azp_adj is None
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
i_s,
i_zp,
symmetric=symmetric)
if i_s is None and i_zp is None and symmetric is True:
x_q, x_s=per_token_quant_int8(x)
x_zp =None
else:
x_q, x_s, x_zp = ops.scaled_int8_quant(x.contiguous(),
i_s,
i_zp,
symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
......
......@@ -10,6 +10,7 @@ from vllm import envs
from vllm.config import CompilationLevel, get_current_vllm_config
from vllm.platforms import current_platform
from vllm.utils import W8a8GetCacheJSON
from lmslim.layers.gemm.int8_utils import per_token_quant_int8
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
......@@ -396,10 +397,15 @@ def apply_int8_linear(
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * static, layer.input_scale is scalar and x_scale is input_scale.
symmetric = azp_adj is None
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if input_scale is None and input_zero_point is None and symmetric is True:
x_q, x_scale=per_token_quant_int8(input)
x_zp =None
else:
x_q, x_scale, x_zp = ops.scaled_int8_quant(input,
input_scale,
input_zero_point,
symmetric=symmetric)
if x_zp is not None:
# Currently, static is always per-tensor and dynamic is per-token
......
......@@ -248,6 +248,7 @@ class W8A8Int8MoEMethod:
def __init__(self, quant_config):
self.quant_config = quant_config
self.tritonsingleton= W8a8GetCacheJSON()
def create_weights(
self,
......@@ -302,6 +303,22 @@ class W8A8Int8MoEMethod:
layer.register_parameter("w2_input_scale", w2_input_scale)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
E=layer.w13_weight.shape[0]
N1=layer.w13_weight.shape[1]
N2=layer.w2_weight.shape[1]
K=layer.w2_weight.shape[2]
if [E,N1,N2,K] not in self.tritonsingleton.moe_weight_shapes:
self.tritonsingleton.moe_weight_shapes.append([E,N1,N2,K])
TOPK= self.tritonsingleton.topk
json_file=self.tritonsingleton.get_moeint8json_name(E,N1,N2,K,TOPK)
configs_dict=self.tritonsingleton.get_moeint8_triton_cache(json_file,E,N1,N2,K,TOPK)
#warmup
if configs_dict:
self.tritonsingleton.triton_moejson_dict.update(configs_dict)
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
layer.w13_weight_scale = Parameter(
......
......@@ -27,6 +27,7 @@ from .deepseek_v2 import (DeepseekV2DecoderLayer,
from .interfaces import SupportsPP
from .utils import maybe_prefix
from vllm import _custom_ops as ops
from vllm.model_executor.layers.quantization.blockwise_int8 import BlockInt8Config
class SharedHead(nn.Module):
......@@ -164,6 +165,9 @@ class DeepSeekMTP(nn.Module, SupportsPP):
self.quant_method = quant_config.get_name()
os.environ['LLAMA_NN'] = '0'
os.environ['LM_NN'] = '0'
# The AWQ layer of MTP uses BlockInt8W8A8.
if self.quant_method == "moe_wna16":
vllm_config.quant_config = BlockInt8Config(is_checkpoint_int8_serialized=True, weight_block_size=[128,128])
self.model = DeepSeekMultiTokenPredictor(vllm_config=vllm_config,
prefix=maybe_prefix(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment