Commit bed32c8d authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.9.2-dev-ds' of...

Merge branch 'v0.9.2-dev-ds' of ssh://10.16.6.30:10022/dcutoolkit/deeplearing/vllm into v0.9.2-dev-ds
parents 5ca1c279 0e92caa0
...@@ -81,7 +81,7 @@ class Attention(nn.Module): ...@@ -81,7 +81,7 @@ class Attention(nn.Module):
calculate_kv_scales = cache_config.calculate_kv_scales calculate_kv_scales = cache_config.calculate_kv_scales
else: else:
kv_cache_dtype = "auto" kv_cache_dtype = "auto"
block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16 block_size = 64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16
is_attention_free = False is_attention_free = False
calculate_kv_scales = False calculate_kv_scales = False
if num_kv_heads is None: if num_kv_heads is None:
...@@ -312,7 +312,7 @@ class MultiHeadAttention(nn.Module): ...@@ -312,7 +312,7 @@ class MultiHeadAttention(nn.Module):
attn_backend = get_attn_backend(head_size, attn_backend = get_attn_backend(head_size,
dtype, dtype,
kv_cache_dtype=None, kv_cache_dtype=None,
block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16, block_size=64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16,
is_attention_free=False) is_attention_free=False)
backend = backend_name_to_enum(attn_backend.get_name()) backend = backend_name_to_enum(attn_backend.get_name())
if current_platform.is_rocm(): if current_platform.is_rocm():
......
...@@ -1499,7 +1499,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"] ...@@ -1499,7 +1499,7 @@ PrefixCachingHashAlgo = Literal["builtin", "sha256"]
class CacheConfig: class CacheConfig:
"""Configuration for the KV cache.""" """Configuration for the KV cache."""
block_size: BlockSize = 64 if envs.VLLM_USE_FLASH_ATTN_PA or envs.VLLM_USE_FLASH_MLA else 16 # type: ignore block_size: BlockSize = 64 if envs.VLLM_USE_FLASH_ATTN_PA and envs.VLLM_USE_FLASH_MLA else 16 # type: ignore
"""Size of a contiguous cache block in number of tokens. This is ignored on """Size of a contiguous cache block in number of tokens. This is ignored on
neuron devices and set to `--max-model-len`. On CUDA devices, only block neuron devices and set to `--max-model-len`. On CUDA devices, only block
sizes up to 32 are supported. On HPU devices, block size defaults to 128. sizes up to 32 are supported. On HPU devices, block size defaults to 128.
......
...@@ -242,11 +242,6 @@ class DeepseekV2MoE(nn.Module): ...@@ -242,11 +242,6 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = self.experts(hidden_states=hidden_states, final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits) router_logits=router_logits)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not self.use_mori_ep:
if shared_output is not None: if shared_output is not None:
if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick: if hidden_states.dtype != torch.float16 or self.dpsk_fp16_quick:
final_hidden_states = final_hidden_states + shared_output final_hidden_states = final_hidden_states + shared_output
...@@ -255,7 +250,11 @@ class DeepseekV2MoE(nn.Module): ...@@ -255,7 +250,11 @@ class DeepseekV2MoE(nn.Module):
# See DeepseekV2DecoderLayer for more details. # See DeepseekV2DecoderLayer for more details.
final_hidden_states = final_hidden_states + shared_output \ final_hidden_states = final_hidden_states + shared_output \
* (1. / self.routed_scaling_factor) * (1. / self.routed_scaling_factor)
else:
final_hidden_states = self.experts(hidden_states=hidden_states,
router_logits=router_logits)
if not self.use_mori_ep:
if self.tp_size > 1: if self.tp_size > 1:
if envs.VLLM_ENABLE_TBO: if envs.VLLM_ENABLE_TBO:
final_hidden_states = self.tbo_all_reduce(final_hidden_states) final_hidden_states = self.tbo_all_reduce(final_hidden_states)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment