Unverified Commit 5e6ddfd6 authored by OlivierDehaene, committed by GitHub

fix(server): fix llamav2 config (#635)

parent cf83f9b6
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -23,6 +23,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple

 # Flash attention imports
@@ -43,6 +44,56 @@ from text_generation_server.utils.layers import (
 )
+
+
+class LlamaConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_scaling = rope_scaling
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
 class LlamaRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
...
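For context: the `num_key_value_heads` fallback in the class above is what lets one config cover both Llama 1 checkpoints (plain multi-head attention, no such key in `config.json`) and Llama 2 checkpoints (grouped-query attention). A minimal sketch of the behavior, assuming the `text_generation_server` package from this repo is importable:

```python
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    LlamaConfig,
)

# Llama 1 style checkpoint: config.json has no num_key_value_heads, so the
# backward-compatibility branch sets it equal to num_attention_heads
# (every query head gets its own key/value head).
llama_v1 = LlamaConfig(num_attention_heads=32)
assert llama_v1.num_key_value_heads == 32

# Llama 2 70B style checkpoint: 64 query heads share 8 key/value heads (GQA).
llama_v2 = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)
assert llama_v2.num_key_value_heads == 8
```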
server/text_generation_server/models/flash_llama.py
@@ -2,13 +2,13 @@ import torch
 import torch.distributed

 from opentelemetry import trace
-from transformers import AutoConfig
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional

 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
+    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
...
@@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )

-        config = AutoConfig.from_pretrained(
+        config = LlamaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
...
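The second file swaps `AutoConfig.from_pretrained` for the vendored `LlamaConfig.from_pretrained`, so config parsing presumably no longer depends on the installed transformers release shipping a GQA-aware Llama config class. A rough usage sketch; the hyperparameters below approximate Llama 2 70B's published `config.json` and are illustrative, not taken from this diff:

```python
from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    LlamaConfig,
)

# Approximate Llama 2 70B hyperparameters (illustrative values).
config = LlamaConfig(
    vocab_size=32000,
    hidden_size=8192,
    intermediate_size=28672,
    num_hidden_layers=80,
    num_attention_heads=64,
    num_key_value_heads=8,  # grouped-query attention: 8 shared KV heads
    max_position_embeddings=4096,
    rms_norm_eps=1e-5,
)

# The flash attention modeling code can rely on these attributes existing,
# whatever transformers version is installed alongside the server.
assert config.num_key_value_heads == 8
assert config.hidden_size % config.num_attention_heads == 0
```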