OpenDAS / text-generation-inference

Commit 44b267ab, authored Dec 14, 2023 by OlivierDehaene
Parent: 28821bfd

fix: fix gpt-q params loading
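This commit threads the caller's `revision` through GPTQ/AWQ parameter loading. Previously, `Weights._set_gptq_params` fetched the quantization config from the Hugging Face Hub without a `revision` argument, so a server pinned to a non-default branch or commit still read its quantization parameters (bits, group size) from the repo's default branch. Each model loader below now passes `revision` alongside `model_id`, and the `hf_hub_download` calls in `weights.py` forward it.

A minimal sketch of the failure mode, assuming a hypothetical quantized repo that keeps different quantization configs on different branches (the repo name, branch, and filename below are illustrative, not taken from the commit):

    from huggingface_hub import hf_hub_download

    model_id = "some-org/some-model-GPTQ"  # hypothetical repo
    revision = "gptq-4bit-32g"             # hypothetical branch

    # Before this commit: no revision was forwarded, so the config always
    # resolved against the repo's default branch.
    config_default = hf_hub_download(model_id, filename="quantize_config.json")

    # After this commit: the requested revision is honored, so the pinned
    # branch supplies its own bits / group size.
    config_pinned = hf_hub_download(
        model_id, filename="quantize_config.json", revision=revision
    )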
Showing 11 changed files with 20 additions and 14 deletions (+20 -14)
server/text_generation_server/models/bloom.py             +1 -1
server/text_generation_server/models/flash_llama.py       +1 -1
server/text_generation_server/models/flash_mistral.py     +1 -1
server/text_generation_server/models/flash_neox.py        +1 -1
server/text_generation_server/models/flash_rw.py          +1 -1
server/text_generation_server/models/flash_santacoder.py  +1 -1
server/text_generation_server/models/galactica.py         +1 -1
server/text_generation_server/models/gpt_neox.py          +1 -1
server/text_generation_server/models/mpt.py               +1 -1
server/text_generation_server/models/opt.py               +1 -1
server/text_generation_server/utils/weights.py            +10 -4
server/text_generation_server/models/bloom.py

...
@@ -81,7 +81,7 @@ class BLOOMSharded(CausalLM):
             prefix="transformer",
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = BloomForCausalLM(config, weights)
...
server/text_generation_server/models/flash_llama.py

...
@@ -64,7 +64,7 @@ class FlashLlama(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = FlashLlamaForCausalLM(config, weights)
         if use_medusa:
...
server/text_generation_server/models/flash_mistral.py

...
@@ -328,7 +328,7 @@ class BaseFlashMistral(FlashCausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize in ["gptq", "awq"]:
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = model_cls(config, weights)
...
server/text_generation_server/models/flash_neox.py

...
@@ -53,7 +53,7 @@ class FlashNeoXSharded(FlashCausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = FlashGPTNeoXForCausalLM(config, weights)
...
server/text_generation_server/models/flash_rw.py

...
@@ -62,7 +62,7 @@ class FlashRWSharded(FlashCausalLM):
         config.quantize = quantize
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = FlashRWForCausalLM(config, weights)
...
server/text_generation_server/models/flash_santacoder.py

...
@@ -63,7 +63,7 @@ class FlashSantacoderSharded(FlashCausalLM):
             aliases={"transformer.wte.weight": ["lm_head.weight"]},
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = FlashSantacoderForCausalLM(config, weights)
...
server/text_generation_server/models/galactica.py

...
@@ -199,7 +199,7 @@ class GalacticaSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = OPTForCausalLM(config, weights)
...
server/text_generation_server/models/gpt_neox.py

...
@@ -57,7 +57,7 @@ class GPTNeoxSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = GPTNeoxForCausalLM(config, weights)
...
server/text_generation_server/models/mpt.py

...
@@ -81,7 +81,7 @@ class MPTSharded(CausalLM):
         filenames = weight_files(model_id, revision=revision, extension=".safetensors")
         weights = Weights(filenames, device, dtype, process_group=self.process_group)
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)
         config.quantize = quantize

         model = MPTForCausalLM(config, weights)
...
server/text_generation_server/models/opt.py

...
@@ -55,7 +55,7 @@ class OPTSharded(CausalLM):
             filenames, device=device, dtype=dtype, process_group=self.process_group
         )
         if config.quantize == "gptq":
-            weights._set_gptq_params(model_id)
+            weights._set_gptq_params(model_id, revision)

         model = OPTForCausalLM(config, weights)
...
server/text_generation_server/utils/weights.py

...
@@ -327,13 +327,15 @@ class Weights:
         return bits, groupsize

-    def _set_gptq_params(self, model_id):
+    def _set_gptq_params(self, model_id, revision):
         filename = "config.json"
         try:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["quantization_config"]["bits"]
...
@@ -344,7 +346,9 @@ class Weights:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["bits"]
...
@@ -355,7 +359,9 @@ class Weights:
             if os.path.exists(os.path.join(model_id, filename)):
                 filename = os.path.join(model_id, filename)
             else:
-                filename = hf_hub_download(model_id, filename=filename)
+                filename = hf_hub_download(
+                    model_id, filename=filename, revision=revision
+                )
             with open(filename, "r") as f:
                 data = json.load(f)
             self.gptq_bits = data["w_bit"]
...
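The three hunks above apply the same pattern to each candidate config file: prefer a copy on local disk under `model_id`, otherwise download from the Hub at the requested `revision` instead of silently defaulting to the main branch. A self-contained sketch of that lookup, with only the local-first ordering and the `revision` forwarding taken from the diff (the helper name is hypothetical):

    import json
    import os

    from huggingface_hub import hf_hub_download

    def load_quant_config(model_id, revision, filename):
        # Prefer a file sitting next to locally stored weights.
        if os.path.exists(os.path.join(model_id, filename)):
            path = os.path.join(model_id, filename)
        else:
            # Otherwise fetch from the Hub at the pinned revision.
            path = hf_hub_download(model_id, filename=filename, revision=revision)
        with open(path, "r") as f:
            return json.load(f)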