OpenDAS / text-generation-inference / Commits

Unverified commit 5e6ddfd6, authored Jul 18, 2023 by OlivierDehaene; committed by GitHub, Jul 18, 2023

fix(server): fix llamav2 config (#635)

parent cf83f9b6

Showing 2 changed files with 53 additions and 2 deletions:

  server/text_generation_server/models/custom_modeling/flash_llama_modeling.py  (+51, -0)
  server/text_generation_server/models/flash_llama.py  (+2, -2)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py  (view file @ 5e6ddfd6)
@@ -23,6 +23,7 @@ import torch.distributed
 from torch import nn
 from transformers.activations import ACT2FN
+from transformers.configuration_utils import PretrainedConfig
 from typing import Optional, List, Tuple
 
 # Flash attention imports
@@ -43,6 +44,56 @@ from text_generation_server.utils.layers import (
 )
 
 
+class LlamaConfig(PretrainedConfig):
+    def __init__(
+        self,
+        vocab_size=32000,
+        hidden_size=4096,
+        intermediate_size=11008,
+        num_hidden_layers=32,
+        num_attention_heads=32,
+        num_key_value_heads=None,
+        hidden_act="silu",
+        max_position_embeddings=2048,
+        initializer_range=0.02,
+        rms_norm_eps=1e-6,
+        use_cache=True,
+        pad_token_id=0,
+        bos_token_id=1,
+        eos_token_id=2,
+        pretraining_tp=1,
+        tie_word_embeddings=False,
+        rope_scaling=None,
+        **kwargs,
+    ):
+        self.vocab_size = vocab_size
+        self.max_position_embeddings = max_position_embeddings
+        self.hidden_size = hidden_size
+        self.intermediate_size = intermediate_size
+        self.num_hidden_layers = num_hidden_layers
+        self.num_attention_heads = num_attention_heads
+
+        # for backward compatibility
+        if num_key_value_heads is None:
+            num_key_value_heads = num_attention_heads
+
+        self.num_key_value_heads = num_key_value_heads
+        self.hidden_act = hidden_act
+        self.initializer_range = initializer_range
+        self.rms_norm_eps = rms_norm_eps
+        self.pretraining_tp = pretraining_tp
+        self.use_cache = use_cache
+        self.rope_scaling = rope_scaling
+
+        super().__init__(
+            pad_token_id=pad_token_id,
+            bos_token_id=bos_token_id,
+            eos_token_id=eos_token_id,
+            tie_word_embeddings=tie_word_embeddings,
+            **kwargs,
+        )
+
+
 class LlamaRMSNorm(nn.Module):
     def __init__(self, prefix, weights, eps=1e-6):
         """
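The important addition is num_key_value_heads: Llama v2 70B uses grouped-query attention, so its config carries fewer key/value heads than attention heads, while older Llama v1 configs omit the field entirely. A minimal sketch of how the backward-compatibility default plays out (assumes the text_generation_server package from this repo is importable; the head counts are illustrative):

from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    LlamaConfig,
)

# Llama v1 style config: no num_key_value_heads in config.json, so the
# backward-compatibility branch falls back to num_attention_heads (plain
# multi-head attention).
v1 = LlamaConfig(num_attention_heads=32)
assert v1.num_key_value_heads == 32

# Llama v2 70B style config: grouped-query attention, where each group of
# query heads shares one key/value head.
v2 = LlamaConfig(num_attention_heads=64, num_key_value_heads=8)
assert v2.num_key_value_heads == 8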
server/text_generation_server/models/flash_llama.py  (view file @ 5e6ddfd6)
@@ -2,13 +2,13 @@ import torch
 import torch.distributed
 
 from opentelemetry import trace
-from transformers import AutoConfig
 from transformers.models.llama import LlamaTokenizer, LlamaTokenizerFast
 from typing import Optional
 
 from text_generation_server.models import FlashCausalLM
 from text_generation_server.models.custom_modeling.flash_llama_modeling import (
     FlashLlamaForCausalLM,
+    LlamaConfig,
 )
 from text_generation_server.utils import (
     initialize_torch_distributed,
@@ -52,7 +52,7 @@ class FlashLlama(FlashCausalLM):
             trust_remote_code=trust_remote_code,
         )
 
-        config = AutoConfig.from_pretrained(
+        config = LlamaConfig.from_pretrained(
             model_id, revision=revision, trust_remote_code=trust_remote_code
         )
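A hedged reading of why this switch matters (the commit message only says "fix llamav2 config"): AutoConfig.from_pretrained resolves a llama checkpoint to the transformers library's own config class, which in versions of that era did not yet define num_key_value_heads, so the grouped-query-attention field would be lost; the vendored LlamaConfig above keeps it. A small, self-contained sketch of the fixed path, using a throwaway directory in place of a real model repo:

import json
import os
import tempfile

from text_generation_server.models.custom_modeling.flash_llama_modeling import (
    LlamaConfig,
)

# Fake checkpoint directory standing in for a Llama-2 model repo; only the
# fields relevant to this fix are written to config.json.
with tempfile.TemporaryDirectory() as model_dir:
    with open(os.path.join(model_dir, "config.json"), "w") as f:
        json.dump({"num_attention_heads": 64, "num_key_value_heads": 8}, f)

    # Same call shape as in the diff above: the vendored class reads
    # config.json directly, so num_key_value_heads is picked up instead of
    # being dropped by an older config class that does not know the field.
    config = LlamaConfig.from_pretrained(model_dir)
    assert config.num_key_value_heads == 8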