Commit a014d6a5 authored by zhuwenwen
Browse files

Update layernorm and activation handling for qwen3_moe

parent 8d6b0b0a
...@@ -29,7 +29,7 @@ try: ...@@ -29,7 +29,7 @@ try:
tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, ) tag_cudagraph_unsafe = (torch._C.Tag.cudagraph_unsafe, )
except AttributeError: except AttributeError:
tag_cudagraph_unsafe = () # type: ignore[assignment] tag_cudagraph_unsafe = () # type: ignore[assignment]
class Attention(nn.Module): class Attention(nn.Module):
"""Attention layer. """Attention layer.
...@@ -220,8 +220,8 @@ class Attention(nn.Module): ...@@ -220,8 +220,8 @@ class Attention(nn.Module):
output_shape = (output_shape output_shape = (output_shape
if output_shape is not None else query.shape) if output_shape is not None else query.shape)
output = torch.zeros(output_shape, output = torch.zeros(output_shape,
dtype=query.dtype, dtype=query.dtype,
device=query.device) device=query.device)
hidden_size = output_shape[-1] hidden_size = output_shape[-1]
# We skip reshaping query, key and value tensors for the MLA # We skip reshaping query, key and value tensors for the MLA
# backend since these tensors have different semantics and are # backend since these tensors have different semantics and are
......
...@@ -1124,7 +1124,7 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1124,7 +1124,7 @@ environment_variables: dict[str, Callable[[], Any]] = {
# vLLM will use lightop moe_align_block_size # vLLM will use lightop moe_align_block_size
"VLLM_USE_LIGHTOP_MOE_ALIGN": "VLLM_USE_LIGHTOP_MOE_ALIGN":
lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in lambda: (os.environ.get("VLLM_USE_LIGHTOP_MOE_ALIGN", "True").lower() in
("true", "1")), ("true", "1")),
# vLLM will use opt merge_attn_states, not triton # vLLM will use opt merge_attn_states, not triton
"VLLM_USE_MERGE_ATTN_STATES_OPT": "VLLM_USE_MERGE_ATTN_STATES_OPT":
lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in lambda: (os.environ.get("VLLM_USE_MERGE_ATTN_STATES_OPT", "True").lower() in
......
...@@ -77,7 +77,7 @@ class SiluAndMul(CustomOp): ...@@ -77,7 +77,7 @@ class SiluAndMul(CustomOp):
"""PyTorch-native implementation equivalent to forward().""" """PyTorch-native implementation equivalent to forward()."""
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO: if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x) return self.forward_cuda(x)
elif envs.VLLM_USE_OPT_OP: elif not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x) return self.forward_cuda(x)
else: else:
d = x.shape[-1] // 2 d = x.shape[-1] // 2
......
...@@ -167,7 +167,7 @@ class RMSNorm(CustomOp): ...@@ -167,7 +167,7 @@ class RMSNorm(CustomOp):
) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]: ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO: if not torch.compiler.is_compiling() and envs.VLLM_ENABLE_TBO:
return self.forward_cuda(x, residual) return self.forward_cuda(x, residual)
elif envs.VLLM_USE_OPT_OP: elif not torch.compiler.is_compiling() and envs.VLLM_USE_OPT_OP:
return self.forward_cuda(x, residual) return self.forward_cuda(x, residual)
else: else:
orig_dtype = x.dtype orig_dtype = x.dtype
......
...@@ -234,7 +234,7 @@ class Qwen3MoeAttention(nn.Module): ...@@ -234,7 +234,7 @@ class Qwen3MoeAttention(nn.Module):
if envs.VLLM_USE_APEX_RN: if envs.VLLM_USE_APEX_RN:
q_by_head = self.q_norm.forward_apex(q_by_head) q_by_head = self.q_norm.forward_apex(q_by_head)
else: else:
q_by_head = self.q_norm(q_by_head) q_by_head = self.q_norm.forward_cuda(q_by_head)
q = q_by_head.view(q.shape) q = q_by_head.view(q.shape)
k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim,
...@@ -242,7 +242,7 @@ class Qwen3MoeAttention(nn.Module): ...@@ -242,7 +242,7 @@ class Qwen3MoeAttention(nn.Module):
if envs.VLLM_USE_APEX_RN: if envs.VLLM_USE_APEX_RN:
k_by_head = self.k_norm.forward_apex(k_by_head) k_by_head = self.k_norm.forward_apex(k_by_head)
else: else:
k_by_head = self.k_norm(k_by_head) k_by_head = self.k_norm.forward_cuda(k_by_head)
k = k_by_head.view(k.shape) k = k_by_head.view(k.shape)
q, k = self.rotary_emb(positions, q, k) q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v) attn_output = self.attn(q, k, v)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment