Unverified Commit c6573698 authored by Jinzhen Lin's avatar Jinzhen Lin Committed by GitHub
Browse files

support `torch.compile` for bailing moe (#21664)

parent 6c66f28f
...@@ -32,6 +32,7 @@ from torch import nn ...@@ -32,6 +32,7 @@ from torch import nn
from transformers.configuration_utils import PretrainedConfig from transformers.configuration_utils import PretrainedConfig
from vllm.attention import Attention from vllm.attention import Attention
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -291,6 +292,7 @@ class BailingMoeBlock(nn.Module): ...@@ -291,6 +292,7 @@ class BailingMoeBlock(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile
class BailingMoeModel(nn.Module): class BailingMoeModel(nn.Module):
def __init__( def __init__(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment