OpenDAS / text-generation-inference — commit e71471be (unverified)
Authored May 15, 2023 by OlivierDehaene; committed via GitHub on May 15, 2023.

feat: add snapshot testing (#282)

Parent: f58f0a03
Changes: 35 files in the commit. Showing 15 changed files with 72 additions and 501 deletions (page 1 of 2).
launcher/tests/bloom_560m.json (+0 -142)
launcher/tests/integration_tests.rs (+0 -172)
launcher/tests/mt0_base.json (+0 -137)
server/text_generation_server/models/bloom.py (+1 -1)
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (+8 -10)
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (+7 -10)
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (+7 -9)
server/text_generation_server/models/flash_llama.py (+10 -5)
server/text_generation_server/models/flash_neox.py (+11 -3)
server/text_generation_server/models/flash_santacoder.py (+13 -5)
server/text_generation_server/models/galactica.py (+1 -1)
server/text_generation_server/models/gpt_neox.py (+1 -1)
server/text_generation_server/models/opt.py (+5 -2)
server/text_generation_server/models/t5.py (+1 -1)
server/text_generation_server/utils/layers.py (+7 -2)
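The three deleted launcher files remove the old Rust integration tests and their JSON fixtures (the snapshot tests that replace them are not among the 15 files shown on this page). Across the Python server files, the recurring change is that `quantize` goes from a `bool` flag to an `Optional[str]` naming a quantization scheme, so truthiness guards such as `device=str(device) if not quantize else "cpu"` become explicit `quantize is None` tests. A minimal sketch of that idiom (the helper name and the "bitsandbytes" scheme are assumptions for illustration):

```python
from typing import Optional

import torch


def weight_load_device(device: torch.device, quantize: Optional[str]) -> str:
    """Illustrative helper mirroring the guards in the diffs below.

    quantize is now a scheme name (e.g. "bitsandbytes", an assumed value)
    or None, so the None case is tested explicitly instead of relying on
    the truthiness of a bool.
    """
    # Load straight to the target device unless weights must first be
    # staged on CPU for quantization.
    return str(device) if quantize is None else "cpu"
```

Each `load_weights`/`safe_open` call site below applies exactly this expression.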
launcher/tests/bloom_560m.json (deleted, 100644 → 0, view file @ f58f0a03)
```json
{
  "generated_text": ".get(\"action\");\n if (action == null) {\n throw new RuntimeException",
  "details": {
    "finish_reason": "length",
    "generated_tokens": 20,
    "seed": null,
    "prefill": [
      { "id": 10264, "text": "Test", "logprob": null },
      { "id": 8821, "text": " request", "logprob": -11.894989 }
    ],
    "tokens": [
      { "id": 17, "text": ".", "logprob": -1.8267672, "special": false },
      { "id": 1587, "text": "get", "logprob": -2.4674969, "special": false },
      { "id": 11, "text": "(", "logprob": -1.906001, "special": false },
      { "id": 5, "text": "\"", "logprob": -1.2279545, "special": false },
      { "id": 4899, "text": "action", "logprob": -4.170299, "special": false },
      { "id": 5, "text": "\"", "logprob": -0.32478866, "special": false },
      { "id": 12, "text": ")", "logprob": -1.0773665, "special": false },
      { "id": 30, "text": ";", "logprob": -0.27640742, "special": false },
      { "id": 837, "text": "\n", "logprob": -1.6970354, "special": false },
      { "id": 1320, "text": " if", "logprob": -1.4495516, "special": false },
      { "id": 375, "text": " (", "logprob": -0.23609057, "special": false },
      { "id": 4899, "text": "action", "logprob": -1.1916996, "special": false },
      { "id": 3535, "text": " ==", "logprob": -0.8918753, "special": false },
      { "id": 5109, "text": " null", "logprob": -0.3933342, "special": false },
      { "id": 12, "text": ")", "logprob": -0.43212673, "special": false },
      { "id": 731, "text": " {", "logprob": -0.17702064, "special": false },
      { "id": 1260, "text": "\n", "logprob": -0.07027565, "special": false },
      { "id": 10519, "text": " throw", "logprob": -1.3915029, "special": false },
      { "id": 2084, "text": " new", "logprob": -0.04201372, "special": false },
      { "id": 150858, "text": " RuntimeException", "logprob": -1.7329919, "special": false }
    ]
  }
}
```

(No newline at end of file.)
launcher/tests/integration_tests.rs (deleted, 100644 → 0, view file @ f58f0a03)
```rust
use float_eq::assert_float_eq;
use serde::Deserialize;
use serde_json::Value;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::thread;
use std::thread::sleep;
use std::time::Duration;
use subprocess::{Popen, PopenConfig, Redirection};

#[derive(Deserialize)]
pub struct Token {
    id: u32,
    text: String,
    logprob: Option<f32>,
    special: bool,
}

#[derive(Deserialize)]
struct Details {
    finish_reason: String,
    generated_tokens: u32,
    tokens: Vec<Token>,
}

#[derive(Deserialize)]
struct GeneratedText {
    generated_text: String,
    details: Details,
}

fn start_launcher(model_id: String, num_shard: usize, port: usize, master_port: usize) -> Popen {
    let argv = vec![
        "text-generation-launcher".to_string(),
        "--model-id".to_string(),
        model_id.clone(),
        "--num-shard".to_string(),
        num_shard.to_string(),
        "--port".to_string(),
        port.to_string(),
        "--master-port".to_string(),
        master_port.to_string(),
        "--shard-uds-path".to_string(),
        format!("/tmp/test-{}-{}-{}", num_shard, port, master_port),
    ];

    let mut launcher = Popen::create(
        &argv,
        PopenConfig {
            stdout: Redirection::Pipe,
            stderr: Redirection::Merge,
            ..Default::default()
        },
    )
    .expect("Could not start launcher");

    // Redirect STDOUT and STDERR to the console
    // (STDERR is merged into STDOUT)
    let launcher_stdout = launcher.stdout.take().unwrap();

    thread::spawn(move || {
        let stdout = BufReader::new(launcher_stdout);
        for line in stdout.lines() {
            println!("{}", line.unwrap());
        }
    });

    for _ in 0..60 {
        let health = reqwest::blocking::get(format!("http://localhost:{}/health", port));
        if health.is_ok() {
            return launcher;
        }
        sleep(Duration::from_secs(2));
    }

    launcher.terminate().unwrap();
    launcher.wait().unwrap();
    panic!("failed to launch {}", model_id)
}

fn test_model(
    model_id: String,
    num_shard: usize,
    port: usize,
    master_port: usize,
) -> GeneratedText {
    let mut launcher = start_launcher(model_id, num_shard, port, master_port);

    let data = r#"
        {
            "inputs": "Test request",
            "parameters": {
                "details": true
            }
        }"#;
    let req: Value = serde_json::from_str(data).unwrap();

    let client = reqwest::blocking::Client::new();
    let res = client
        .post(format!("http://localhost:{}/generate", port))
        .json(&req)
        .send();

    launcher.terminate().unwrap();
    launcher.wait().unwrap();

    let result: GeneratedText = res.unwrap().json().unwrap();
    result
}

fn read_json(name: &str) -> GeneratedText {
    let mut d = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
    d.push("tests/");
    d.push(name);

    let file = File::open(d).unwrap();
    let reader = BufReader::new(file);

    let result: GeneratedText = serde_json::from_reader(reader).unwrap();
    result
}

fn compare_results(result: GeneratedText, expected: GeneratedText) {
    assert_eq!(result.generated_text, expected.generated_text);
    assert_eq!(result.details.finish_reason, expected.details.finish_reason);
    assert_eq!(result.details.generated_tokens, expected.details.generated_tokens);

    for (token, expected_token) in result
        .details
        .tokens
        .into_iter()
        .zip(expected.details.tokens.into_iter())
    {
        assert_eq!(token.id, expected_token.id);
        assert_eq!(token.text, expected_token.text);
        assert_eq!(token.special, expected_token.special);
        if let Some(logprob) = token.logprob {
            let expected_logprob = expected_token.logprob.unwrap();
            assert_float_eq!(logprob, expected_logprob, abs <= 0.001);
        } else {
            assert_eq!(token.logprob, expected_token.logprob);
        }
    }
}

#[test]
fn test_bloom_560m() {
    let expected = read_json("bloom_560m.json");
    let result = test_model("bigscience/bloom-560m".to_string(), 1, 3000, 29500);
    compare_results(result, expected);
}

#[test]
fn test_bloom_560m_distributed() {
    let expected = read_json("bloom_560m.json");
    let result = test_model("bigscience/bloom-560m".to_string(), 2, 3001, 29501);
    compare_results(result, expected);
}

#[test]
fn test_mt0_base() {
    let expected = read_json("mt0_base.json");
    let result = test_model("bigscience/mt0-base".to_string(), 1, 3002, 29502);
    compare_results(result, expected);
}
```
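The deleted test above drove the launcher binary and compared the `/generate` response field by field against the JSON fixtures, with an absolute tolerance on logprobs. As a hedged sketch, the same comparison logic in Python (the `generate` helper, the port handling, and `SNAPSHOT_DIR` are illustrative assumptions, not the fixtures this PR actually adds):

```python
import json
import math
from pathlib import Path

import requests

SNAPSHOT_DIR = Path(__file__).parent  # assumed location of the JSON fixtures


def generate(port: int) -> dict:
    """Send the same request the Rust test sent and return the parsed body."""
    resp = requests.post(
        f"http://localhost:{port}/generate",
        json={"inputs": "Test request", "parameters": {"details": True}},
    )
    resp.raise_for_status()
    return resp.json()


def assert_matches_snapshot(result: dict, name: str, abs_tol: float = 1e-3) -> None:
    """Mirror of compare_results: exact fields, tolerant logprob comparison."""
    expected = json.loads((SNAPSHOT_DIR / name).read_text())
    assert result["generated_text"] == expected["generated_text"]
    assert result["details"]["finish_reason"] == expected["details"]["finish_reason"]
    assert result["details"]["generated_tokens"] == expected["details"]["generated_tokens"]
    for got, want in zip(result["details"]["tokens"], expected["details"]["tokens"]):
        assert got["id"] == want["id"]
        assert got["text"] == want["text"]
        assert got["special"] == want["special"]
        if want["logprob"] is None:
            assert got["logprob"] is None
        else:
            assert math.isclose(got["logprob"], want["logprob"], abs_tol=abs_tol)
```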
launcher/tests/mt0_base.json (deleted, 100644 → 0, view file @ f58f0a03)
```json
{
  "generated_text": "\"\"\"Test the contents of the contents of the contents. \"\"\" test_test",
  "details": {
    "finish_reason": "length",
    "generated_tokens": 20,
    "seed": null,
    "prefill": [
      { "id": 0, "text": "<pad>", "logprob": null }
    ],
    "tokens": [
      { "id": 259, "text": "", "logprob": -1.3656927, "special": false },
      { "id": 215100, "text": "\"\"\"", "logprob": -2.6551573, "special": false },
      { "id": 46138, "text": "Test", "logprob": -1.8059857, "special": false },
      { "id": 287, "text": " the", "logprob": -1.2102449, "special": false },
      { "id": 259, "text": " ", "logprob": -1.6057279, "special": false },
      { "id": 49076, "text": "contents", "logprob": -3.6060903, "special": false },
      { "id": 304, "text": " of", "logprob": -0.5270343, "special": false },
      { "id": 287, "text": " the", "logprob": -0.62522805, "special": false },
      { "id": 259, "text": " ", "logprob": -1.4069618, "special": false },
      { "id": 49076, "text": "contents", "logprob": -2.621994, "special": false },
      { "id": 304, "text": " of", "logprob": -1.3172221, "special": false },
      { "id": 287, "text": " the", "logprob": -0.3501925, "special": false },
      { "id": 259, "text": " ", "logprob": -0.7219573, "special": false },
      { "id": 49076, "text": "contents", "logprob": -1.0494149, "special": false },
      { "id": 260, "text": ".", "logprob": -1.0803378, "special": false },
      { "id": 259, "text": " ", "logprob": -0.32933083, "special": false },
      { "id": 215100, "text": "\"\"\"", "logprob": -0.11268901, "special": false },
      { "id": 2978, "text": " test", "logprob": -1.5846587, "special": false },
      { "id": 290, "text": "_", "logprob": -0.49796978, "special": false },
      { "id": 4125, "text": "test", "logprob": -2.0026445, "special": false }
    ]
  }
}
```

(No newline at end of file.)
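As reconstructed, both deleted fixtures are internally consistent: `generated_tokens` equals `len(tokens)`, and `generated_text` is the concatenation of the token texts (true for these two files, though detokenization does not guarantee this in general). A small sanity check one could run over such snapshot files:

```python
import json


def check_snapshot(path: str) -> None:
    """Verify a snapshot's token count and that token texts rebuild the output."""
    with open(path) as f:
        snap = json.load(f)
    details = snap["details"]
    assert details["generated_tokens"] == len(details["tokens"])
    assert snap["generated_text"] == "".join(t["text"] for t in details["tokens"])


# Run against the pre-commit checkout (parent f58f0a03), where the files exist.
check_snapshot("launcher/tests/bloom_560m.json")
check_snapshot("launcher/tests/mt0_base.json")
```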
server/text_generation_server/models/bloom.py (view file @ e71471be)
```diff
@@ -129,7 +129,7 @@ class BLOOMSharded(BLOOM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     full_name = f"transformer.{name}"
```
server/text_generation_server/models/custom_modeling/flash_llama_modeling.py (view file @ e71471be)
```diff
@@ -21,16 +21,14 @@
 import torch
 import torch.distributed

-from torch.nn import functional as F
 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional

 # Flash attention imports
 import flash_attn_cuda
-import dropout_layer_norm

-from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
     FastLinear,
     TensorParallelRowLinear,
@@ -332,15 +330,15 @@ class FlashLlamaModel(torch.nn.Module):
         self.head_size = self.layers[0].self_attn.head_size
         self.num_heads = self.layers[0].self_attn.num_heads

-    def post_load_weights(self, load_in_8bit: bool = False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if isinstance(self.embed_tokens, TensorParallelEmbedding):
             self.embed_tokens.add_null_idx()
         for layer in self.layers:
             layer: FlashLlamaLayer
-            layer.self_attn.query_key_value.prepare_weights(load_in_8bit)
-            layer.self_attn.o_proj.prepare_weights(load_in_8bit)
-            layer.mlp.gate_up_proj.prepare_weights(load_in_8bit)
-            layer.mlp.down_proj.prepare_weights(load_in_8bit)
+            layer.self_attn.query_key_value.prepare_weights(quantize)
+            layer.self_attn.o_proj.prepare_weights(quantize)
+            layer.mlp.gate_up_proj.prepare_weights(quantize)
+            layer.mlp.down_proj.prepare_weights(quantize)

     def forward(
         self,
@@ -429,8 +427,8 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.model.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.model.post_load_weights(quantize)
         self.lm_head.prepare_weights()

     def forward(
```
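`post_load_weights` now threads the optional scheme string into each layer's `prepare_weights` instead of a boolean. The `utils/layers.py` diff that consumes it is collapsed at the bottom of this page, so the dispatch below is only a plausible sketch; the class name and the "bitsandbytes" branch are assumptions:

```python
from typing import Optional


class FastLinearSketch:
    """Illustrative stand-in; not the actual FastLinear from layers.py."""

    def prepare_weights(self, quantize: Optional[str] = None) -> None:
        if quantize is None:
            # Keep full-precision weights untouched.
            return
        if quantize == "bitsandbytes":  # assumed scheme name
            # Convert the weight matrix to an 8-bit representation here.
            return
        raise ValueError(f"Unknown quantization scheme: {quantize}")
```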
server/text_generation_server/models/custom_modeling/flash_neox_modeling.py (view file @ e71471be)
```diff
@@ -21,8 +21,6 @@
 import torch
 import torch.distributed

-from torch.nn import functional as F
-
 from torch import nn
 from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
@@ -32,7 +30,6 @@ from typing import Optional
 # Flash attention imports
 import flash_attn_cuda

-from flash_attn.layers.rotary import RotaryEmbedding
 from text_generation_server.utils.layers import (
     FastLinear,
     TensorParallelRowLinear,
@@ -345,16 +342,16 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
         self.head_size = self.layers[0].attention.head_size
         self.num_heads = self.layers[0].attention.num_heads

-    def post_load_weights(self, load_in_8bit=False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if isinstance(self.embed_in, TensorParallelEmbedding):
             self.embed_in.add_null_idx()
         for layer in self.layers:
             layer: FlashNeoXLayer
             layer.attention.shuffle_qkv_dims()
-            layer.attention.query_key_value.prepare_weights(load_in_8bit)
-            layer.attention.dense.prepare_weights(load_in_8bit)
-            layer.mlp.dense_h_to_4h.prepare_weights(load_in_8bit)
-            layer.mlp.dense_4h_to_h.prepare_weights(load_in_8bit)
+            layer.attention.query_key_value.prepare_weights(quantize)
+            layer.attention.dense.prepare_weights(quantize)
+            layer.mlp.dense_h_to_4h.prepare_weights(quantize)
+            layer.mlp.dense_4h_to_h.prepare_weights(quantize)

     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
@@ -457,8 +454,8 @@ class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel):
             config.hidden_size, config.vocab_size, bias=False
         )

-    def post_load_weights(self, load_in_8bit=False):
-        self.gpt_neox.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.gpt_neox.post_load_weights(quantize)
         self.embed_out.prepare_weights()

     @classmethod
```
server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py (view file @ e71471be)
```diff
 import torch
 import torch.distributed

-import torch.nn.functional as F

 from torch import nn
 from transformers.activations import ACT2FN
 from typing import Optional
@@ -261,16 +259,16 @@ class FlashSantacoderModel(nn.Module):
         self.head_size = self.h[0].attn.head_size
         self.num_heads = self.h[0].attn.num_heads

-    def post_load_weights(self, load_in_8bit: bool = False):
+    def post_load_weights(self, quantize: Optional[str] = None):
         if self.tp_embeddings:
             self.wte.add_null_idx()
             self.wpe.add_null_idx()
         for layer in self.h:
             layer: Block
-            layer.attn.c_attn.prepare_weights(load_in_8bit)
-            layer.attn.c_proj.prepare_weights(load_in_8bit)
-            layer.mlp.c_fc.prepare_weights(load_in_8bit)
-            layer.mlp.c_proj.prepare_weights(load_in_8bit)
+            layer.attn.c_attn.prepare_weights(quantize)
+            layer.attn.c_proj.prepare_weights(quantize)
+            layer.mlp.c_fc.prepare_weights(quantize)
+            layer.mlp.c_proj.prepare_weights(quantize)

     def forward(
         self,
@@ -347,8 +345,8 @@ class FlashSantacoderForCausalLM(nn.Module):
         else:
             self.lm_head = FastLinear(config.hidden_size, config.vocab_size, bias=False)

-    def post_load_weights(self, load_in_8bit: bool = False):
-        self.transformer.post_load_weights(load_in_8bit)
+    def post_load_weights(self, quantize: Optional[str] = None):
+        self.transformer.post_load_weights(quantize)
         self.lm_head.prepare_weights()

     def forward(
```
server/text_generation_server/models/flash_llama.py (view file @ e71471be)
```diff
@@ -28,7 +28,12 @@ tracer = trace.get_tracer(__name__)
 class FlashLlama(FlashCausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16
@@ -72,14 +77,14 @@ class FlashLlama(FlashCausalLM):
     def load_weights(
         model,
         filenames: List[Path],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
     ):
         for filename in filenames:
             state_dict = torch.load(filename, map_location="cpu")
             for key, value in state_dict.items():
-                value = value.to(device if not quantize else "cpu").to(dtype)
+                value = value.to(device if quantize is None else "cpu").to(dtype)
                 layer_name = ".".join(key.split(".")[:4])
@@ -199,7 +204,7 @@ class FlashLlamaSharded(FlashLlama):
     def load_weights(
         model,
         filenames: List[str],
-        quantize: bool,
+        quantize: Optional[str],
         device: torch.device,
         dtype: torch.dtype,
         rank: int,
@@ -207,7 +212,7 @@ class FlashLlamaSharded(FlashLlama):
     ):
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     slice_ = f.get_slice(name)
```
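The widened `__init__` signature changes call sites from a boolean flag to a scheme name or None; a hypothetical instantiation (the model id and scheme name are placeholders, not values from this commit):

```python
# Hypothetical usage of FlashLlama as defined in the file above.
unquantized = FlashLlama("some-org/some-llama", revision=None, quantize=None)
eight_bit = FlashLlama("some-org/some-llama", quantize="bitsandbytes")  # assumed scheme
```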
server/text_generation_server/models/flash_neox.py (view file @ e71471be)
```diff
@@ -23,7 +23,12 @@ tracer = trace.get_tracer(__name__)
 class FlashNeoX(FlashCausalLM):
-    def __init__(self, model_id: str, revision: Optional[str] = None, quantize=False):
+    def __init__(
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
+    ):
         super(FlashNeoX, self).__init__(
             FlashGPTNeoXForCausalLM, model_id, revision, quantize
         )
@@ -31,7 +36,10 @@ class FlashNeoX(FlashCausalLM):
 class FlashNeoXSharded(FlashNeoX):
     def __init__(
-        self, model_id: str, revision: Optional[str] = None, quantize: bool = False
+        self,
+        model_id: str,
+        revision: Optional[str] = None,
+        quantize: Optional[str] = None,
     ):
         self.process_group, rank, world_size = initialize_torch_distributed()
         if torch.cuda.is_available():
@@ -89,7 +97,7 @@ class FlashNeoXSharded(FlashNeoX):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
```
server/text_generation_server/models/flash_santacoder.py (view file @ e71471be; diff collapsed)
server/text_generation_server/models/galactica.py (view file @ e71471be)
```diff
@@ -255,7 +255,7 @@ class GalacticaSharded(Galactica):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     if name == "lm_head.weight":
```
server/text_generation_server/models/gpt_neox.py (view file @ e71471be)
```diff
@@ -94,7 +94,7 @@ class GPTNeoxSharded(CausalLM):
         parameters = dict(model.named_parameters())
         for file in filenames:
             with safe_open(
-                file, framework="pt", device=str(device) if not quantize else "cpu"
+                file, framework="pt", device=str(device) if quantize is None else "cpu"
             ) as f:
                 for name in f.keys():
                     module_name, param_name = name.rsplit(".", 1)
```
server/text_generation_server/models/opt.py (view file @ e71471be; diff collapsed)
server/text_generation_server/models/t5.py (view file @ e71471be; diff collapsed)
server/text_generation_server/utils/layers.py (view file @ e71471be; diff collapsed)