Unverified Commit 8a02cd04 authored by Yongzao's avatar Yongzao Committed by GitHub
Browse files

[torch.compile] Adding torch compile annotations to some models (#9639)


Signed-off-by: default avataryoukaichao <youkaichao@gmail.com>
Co-authored-by: default avataryoukaichao <youkaichao@gmail.com>
parent 4fdc581f
...@@ -144,7 +144,7 @@ Text Generation ...@@ -144,7 +144,7 @@ Text Generation
- ✅︎ - ✅︎
* - :code:`JAISLMHeadModel` * - :code:`JAISLMHeadModel`
- Jais - Jais
- :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - :code:`inceptionai/jais-13b`, :code:`inceptionai/jais-13b-chat`, :code:`inceptionai/jais-30b-v3`, :code:`inceptionai/jais-30b-chat-v3`, etc.
- -
- ✅︎ - ✅︎
* - :code:`JambaForCausalLM` * - :code:`JambaForCausalLM`
......
...@@ -145,7 +145,7 @@ TEXT_GENERATION_MODELS = { ...@@ -145,7 +145,7 @@ TEXT_GENERATION_MODELS = {
# Uses Llama # Uses Llama
# "internlm/internlm-chat-7b": PPTestSettings.fast(), # "internlm/internlm-chat-7b": PPTestSettings.fast(),
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), "internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True),
"core42/jais-13b-chat": PPTestSettings.fast(), "inceptionai/jais-13b-chat": PPTestSettings.fast(),
# TODO: Implement PP # TODO: Implement PP
# "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(), # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(),
"meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(),
......
# coding=utf-8 # coding=utf-8
# Adapted from # Adapted from
# https://huggingface.co/core42/jais-30b-chat-v3/blob/main/modeling_jais.py # https://huggingface.co/inceptionai/jais-30b-chat-v3/blob/main/modeling_jais.py
# Copyright 2023 The vLLM team. # Copyright 2023 The vLLM team.
# Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights # Copyright 2023 the Jais authors and HuggingFace Inc. team. All rights
# reserved. # reserved.
...@@ -26,6 +26,7 @@ import torch ...@@ -26,6 +26,7 @@ import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -212,6 +213,7 @@ class JAISBlock(nn.Module): ...@@ -212,6 +213,7 @@ class JAISBlock(nn.Module):
return hidden_states return hidden_states
@support_torch_compile
class JAISModel(nn.Module): class JAISModel(nn.Module):
def __init__( def __init__(
......
...@@ -29,6 +29,7 @@ from torch import nn ...@@ -29,6 +29,7 @@ from torch import nn
from transformers import PretrainedConfig from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size, get_tensor_model_parallel_world_size,
...@@ -348,6 +349,7 @@ class MiniCPMDecoderLayer(nn.Module): ...@@ -348,6 +349,7 @@ class MiniCPMDecoderLayer(nn.Module):
return hidden_states, None return hidden_states, None
@support_torch_compile
class MiniCPMModel(nn.Module): class MiniCPMModel(nn.Module):
def __init__( def __init__(
......
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
import torch.nn as nn import torch.nn as nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size) get_tensor_model_parallel_world_size)
...@@ -204,6 +205,7 @@ class MPTBlock(nn.Module): ...@@ -204,6 +205,7 @@ class MPTBlock(nn.Module):
return hidden_states return hidden_states
@support_torch_compile
class MPTModel(nn.Module): class MPTModel(nn.Module):
def __init__( def __init__(
......
...@@ -27,6 +27,7 @@ import torch ...@@ -27,6 +27,7 @@ import torch
from torch import nn from torch import nn
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, LoRAConfig from vllm.config import CacheConfig, LoRAConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.activation import get_act_fn
...@@ -290,6 +291,7 @@ class NemotronDecoderLayer(nn.Module): ...@@ -290,6 +291,7 @@ class NemotronDecoderLayer(nn.Module):
return hidden_states, residual return hidden_states, residual
@support_torch_compile
class NemotronModel(nn.Module): class NemotronModel(nn.Module):
def __init__( def __init__(
......
...@@ -28,6 +28,7 @@ from torch import nn ...@@ -28,6 +28,7 @@ from torch import nn
from transformers import OlmoConfig from transformers import OlmoConfig
from vllm.attention import Attention, AttentionMetadata from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig from vllm.config import CacheConfig
from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size
from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.activation import SiluAndMul
...@@ -221,6 +222,7 @@ class OlmoDecoderLayer(nn.Module): ...@@ -221,6 +222,7 @@ class OlmoDecoderLayer(nn.Module):
return hidden_states return hidden_states
@support_torch_compile
class OlmoModel(nn.Module): class OlmoModel(nn.Module):
def __init__(self, def __init__(self,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment